mirror of
https://gitlab.freedesktop.org/gstreamer/gst-plugins-rs.git
synced 2024-11-26 13:31:00 +00:00
net/aws: enqueue transcribed buffers within the ws loop
Instead of sending transcription events to the src pad loop, this commit enqueues the transcribed buffers immediately in the ws loop, then notifies the src pad loop. The src pad loop is only in charge of dequeuing the buffers. This should help with upcoming evolutions. Part-of: <https://gitlab.freedesktop.org/gstreamer/gst-plugins-rs/-/merge_requests/1104>
This commit is contained in:
parent
00153754bb
commit
36ae29d746
1 changed files with 100 additions and 107 deletions
|
@ -66,7 +66,7 @@ struct Settings {
|
||||||
vocabulary_filter_method: AwsTranscriberVocabularyFilterMethod,
|
vocabulary_filter_method: AwsTranscriberVocabularyFilterMethod,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::default::Default for Settings {
|
impl Default for Settings {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
latency: DEFAULT_LATENCY,
|
latency: DEFAULT_LATENCY,
|
||||||
|
@ -112,26 +112,25 @@ impl TranscriptionSettings {
|
||||||
struct State {
|
struct State {
|
||||||
client: Option<aws_transcribe::Client>,
|
client: Option<aws_transcribe::Client>,
|
||||||
buffer_tx: Option<mpsc::Sender<gst::Buffer>>,
|
buffer_tx: Option<mpsc::Sender<gst::Buffer>>,
|
||||||
transcript_tx: Option<mpsc::Sender<model::TranscriptEvent>>,
|
transcript_notif_tx: Option<mpsc::Sender<()>>,
|
||||||
ws_loop_handle: Option<task::JoinHandle<Result<(), gst::ErrorMessage>>>,
|
ws_loop_handle: Option<task::JoinHandle<Result<(), gst::ErrorMessage>>>,
|
||||||
in_segment: gst::FormattedSegment<gst::ClockTime>,
|
in_segment: gst::FormattedSegment<gst::ClockTime>,
|
||||||
out_segment: gst::FormattedSegment<gst::ClockTime>,
|
out_segment: gst::FormattedSegment<gst::ClockTime>,
|
||||||
seqnum: gst::Seqnum,
|
seqnum: gst::Seqnum,
|
||||||
buffers: VecDeque<gst::Buffer>,
|
buffers: VecDeque<gst::Buffer>,
|
||||||
send_eos: bool,
|
send_eos: bool,
|
||||||
// FIXME never set to true
|
|
||||||
discont: bool,
|
discont: bool,
|
||||||
partial_index: usize,
|
partial_index: usize,
|
||||||
send_events: bool,
|
send_events: bool,
|
||||||
start_time: Option<gst::ClockTime>,
|
start_time: Option<gst::ClockTime>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::default::Default for State {
|
impl Default for State {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
client: None,
|
client: None,
|
||||||
buffer_tx: None,
|
buffer_tx: None,
|
||||||
transcript_tx: None,
|
transcript_notif_tx: None,
|
||||||
ws_loop_handle: None,
|
ws_loop_handle: None,
|
||||||
in_segment: gst::FormattedSegment::new(),
|
in_segment: gst::FormattedSegment::new(),
|
||||||
out_segment: gst::FormattedSegment::new(),
|
out_segment: gst::FormattedSegment::new(),
|
||||||
|
@ -297,8 +296,11 @@ impl Transcriber {
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
fn enqueue(&self, state: &mut State, items: &[model::Item], partial: bool) {
|
/// Enqueues a buffer for each of the provided stable items.
|
||||||
let lateness = self.settings.lock().unwrap().lateness;
|
///
|
||||||
|
/// Returns `true` if at least one buffer was enqueued.
|
||||||
|
fn enqueue(&self, items: &[model::Item], partial: bool, lateness: gst::ClockTime) -> bool {
|
||||||
|
let mut state = self.state.lock().unwrap();
|
||||||
|
|
||||||
if items.len() <= state.partial_index {
|
if items.len() <= state.partial_index {
|
||||||
gst::error!(
|
gst::error!(
|
||||||
|
@ -313,53 +315,55 @@ impl Transcriber {
|
||||||
state.partial_index = 0;
|
state.partial_index = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
return;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
for item in &items[state.partial_index..] {
|
let mut enqueued = false;
|
||||||
let start_time = ((item.start_time * 1_000_000_000.0) as u64).nseconds() + lateness;
|
|
||||||
let end_time = ((item.end_time * 1_000_000_000.0) as u64).nseconds() + lateness;
|
|
||||||
|
|
||||||
|
for item in &items[state.partial_index..] {
|
||||||
if !item.stable().unwrap_or(false) {
|
if !item.stable().unwrap_or(false) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
// FIXME could probably just unwrap
|
let Some(content) = item.content() else { continue };
|
||||||
if let Some(content) = item.content() {
|
|
||||||
/* Should be sent now */
|
|
||||||
gst::debug!(
|
|
||||||
CAT,
|
|
||||||
imp: self,
|
|
||||||
"Item is ready for queuing: {content}, PTS {start_time}",
|
|
||||||
);
|
|
||||||
|
|
||||||
let mut buf = gst::Buffer::from_mut_slice(content.to_string().into_bytes());
|
let start_time = ((item.start_time * 1_000_000_000.0) as u64).nseconds() + lateness;
|
||||||
{
|
let end_time = ((item.end_time * 1_000_000_000.0) as u64).nseconds() + lateness;
|
||||||
let buf = buf.get_mut().unwrap();
|
|
||||||
|
|
||||||
if state.discont {
|
/* Should be sent now */
|
||||||
buf.set_flags(gst::BufferFlags::DISCONT);
|
gst::debug!(
|
||||||
state.discont = false;
|
CAT,
|
||||||
}
|
imp: self,
|
||||||
|
"Item is ready for queuing: {content}, PTS {start_time}",
|
||||||
|
);
|
||||||
|
|
||||||
buf.set_pts(start_time);
|
let mut buf = gst::Buffer::from_mut_slice(content.to_string().into_bytes());
|
||||||
buf.set_duration(end_time - start_time);
|
{
|
||||||
|
let buf = buf.get_mut().unwrap();
|
||||||
|
|
||||||
|
if state.discont {
|
||||||
|
buf.set_flags(gst::BufferFlags::DISCONT);
|
||||||
|
state.discont = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
state.partial_index += 1;
|
buf.set_pts(start_time);
|
||||||
|
buf.set_duration(end_time - start_time);
|
||||||
state.buffers.push_back(buf);
|
|
||||||
} else {
|
|
||||||
gst::debug!(CAT, imp: self, "None transcript item content");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
state.partial_index += 1;
|
||||||
|
|
||||||
|
state.buffers.push_back(buf);
|
||||||
|
enqueued = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
if !partial {
|
if !partial {
|
||||||
state.partial_index = 0;
|
state.partial_index = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
enqueued
|
||||||
}
|
}
|
||||||
|
|
||||||
fn pad_loop_fn(&self, receiver: &mut mpsc::Receiver<model::TranscriptEvent>) -> Result<(), ()> {
|
fn pad_loop_fn(&self, transcript_notif_rx: &mut mpsc::Receiver<()>) {
|
||||||
let mut events = {
|
let mut events = {
|
||||||
let mut events = vec![];
|
let mut events = vec![];
|
||||||
|
|
||||||
|
@ -400,56 +404,24 @@ impl Transcriber {
|
||||||
}
|
}
|
||||||
|
|
||||||
let future = async move {
|
let future = async move {
|
||||||
enum Winner {
|
let timeout = tokio::time::sleep(GRANULARITY.into()).fuse();
|
||||||
TranscriptEvent(Option<model::TranscriptEvent>),
|
futures::pin_mut!(timeout);
|
||||||
Timeout,
|
|
||||||
}
|
|
||||||
|
|
||||||
let timer = tokio::time::sleep(GRANULARITY.into()).fuse();
|
futures::select! {
|
||||||
futures::pin_mut!(timer);
|
notif = transcript_notif_rx.next() => {
|
||||||
|
if notif.is_none() {
|
||||||
let race_res = futures::select_biased! {
|
// Transcriber loop terminated
|
||||||
transcript_evt = receiver.next() => Winner::TranscriptEvent(transcript_evt),
|
self.state.lock().unwrap().send_eos = true;
|
||||||
_ = timer => Winner::Timeout,
|
return;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
_ = timeout => (),
|
||||||
};
|
};
|
||||||
|
|
||||||
use Winner::*;
|
|
||||||
match race_res {
|
|
||||||
TranscriptEvent(Some(transcript_evt)) => {
|
|
||||||
if let Some(result) = transcript_evt
|
|
||||||
.transcript
|
|
||||||
.as_ref()
|
|
||||||
.and_then(|transcript| transcript.results())
|
|
||||||
.and_then(|results| results.get(0))
|
|
||||||
{
|
|
||||||
gst::trace!(CAT, imp: self, "Received: {result:?}");
|
|
||||||
|
|
||||||
if let Some(alternative) = result
|
|
||||||
.alternatives
|
|
||||||
.as_ref()
|
|
||||||
.and_then(|alternatives| alternatives.get(0))
|
|
||||||
{
|
|
||||||
if let Some(items) = alternative.items() {
|
|
||||||
let mut state = self.state.lock().unwrap();
|
|
||||||
self.enqueue(&mut state, items, result.is_partial)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
TranscriptEvent(None) => {
|
|
||||||
gst::info!(CAT, imp: self, "Transcript evt channel disconnected");
|
|
||||||
// Something bad happened elsewhere, let the other side report.
|
|
||||||
return Err(());
|
|
||||||
}
|
|
||||||
Timeout => (),
|
|
||||||
}
|
|
||||||
|
|
||||||
if !self.dequeue() {
|
if !self.dequeue() {
|
||||||
gst::info!(CAT, imp: self, "Failed to dequeue buffer, pausing");
|
gst::info!(CAT, imp: self, "Failed to dequeue buffer, pausing");
|
||||||
let _ = self.srcpad.pause_task();
|
let _ = self.srcpad.pause_task();
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let _enter = RUNTIME.enter();
|
let _enter = RUNTIME.enter();
|
||||||
|
@ -459,24 +431,19 @@ impl Transcriber {
|
||||||
fn start_task(&self) -> Result<(), gst::LoggableError> {
|
fn start_task(&self) -> Result<(), gst::LoggableError> {
|
||||||
let mut state = self.state.lock().unwrap();
|
let mut state = self.state.lock().unwrap();
|
||||||
|
|
||||||
let (transcript_tx, mut transcript_rx) = mpsc::channel(1);
|
let (transcript_notif_tx, mut transcript_notif_rx) = mpsc::channel(1);
|
||||||
|
|
||||||
let imp = self.ref_counted();
|
let imp = self.ref_counted();
|
||||||
let res = self.srcpad.start_task(move || {
|
let res = self
|
||||||
if imp.pad_loop_fn(&mut transcript_rx).is_err() {
|
.srcpad
|
||||||
// Pad loop fn reported an unrecoverable error.
|
.start_task(move || imp.pad_loop_fn(&mut transcript_notif_rx));
|
||||||
// FIXME we should probably stop the task as
|
|
||||||
// there's nothing we can do about it except restarting.
|
|
||||||
let _ = imp.srcpad.pause_task();
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
if res.is_err() {
|
if res.is_err() {
|
||||||
state.transcript_tx = None;
|
state.transcript_notif_tx = None;
|
||||||
return Err(gst::loggable_error!(CAT, "Failed to start pad task"));
|
return Err(gst::loggable_error!(CAT, "Failed to start pad task"));
|
||||||
}
|
}
|
||||||
|
|
||||||
state.transcript_tx = Some(transcript_tx);
|
state.transcript_notif_tx = Some(transcript_notif_tx);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -490,7 +457,7 @@ impl Transcriber {
|
||||||
ws_loop_handle.abort();
|
ws_loop_handle.abort();
|
||||||
}
|
}
|
||||||
|
|
||||||
state.transcript_tx = None;
|
state.transcript_notif_tx = None;
|
||||||
state.buffer_tx = None;
|
state.buffer_tx = None;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -652,7 +619,8 @@ impl Transcriber {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
let (client_stage, transcription_settings, transcript_tx) = {
|
let (client_stage, transcription_settings, lateness, transcript_notif_tx);
|
||||||
|
{
|
||||||
let mut state = self.state.lock().unwrap();
|
let mut state = self.state.lock().unwrap();
|
||||||
|
|
||||||
if let Some(ref ws_loop_handle) = state.ws_loop_handle {
|
if let Some(ref ws_loop_handle) = state.ws_loop_handle {
|
||||||
|
@ -667,14 +635,15 @@ impl Transcriber {
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
let transcript_tx = state
|
transcript_notif_tx = state
|
||||||
.transcript_tx
|
.transcript_notif_tx
|
||||||
.take()
|
.take()
|
||||||
.expect("attempting to spawn the ws loop, but the srcpad task hasn't been started");
|
.expect("attempting to spawn the ws loop, but the srcpad task hasn't been started");
|
||||||
|
|
||||||
let settings = self.settings.lock().unwrap();
|
let settings = self.settings.lock().unwrap();
|
||||||
|
|
||||||
if settings.latency + settings.lateness <= 2 * GRANULARITY {
|
lateness = settings.lateness;
|
||||||
|
if settings.latency + lateness <= 2 * GRANULARITY {
|
||||||
const ERR: &str = "latency + lateness must be greater than 200 milliseconds";
|
const ERR: &str = "latency + lateness must be greater than 200 milliseconds";
|
||||||
gst::error!(CAT, imp: self, "{ERR}");
|
gst::error!(CAT, imp: self, "{ERR}");
|
||||||
return Err(gst::error_msg!(gst::LibraryError::Settings, ["{ERR}"]));
|
return Err(gst::error_msg!(gst::LibraryError::Settings, ["{ERR}"]));
|
||||||
|
@ -684,9 +653,9 @@ impl Transcriber {
|
||||||
let s = in_caps.structure(0).unwrap();
|
let s = in_caps.structure(0).unwrap();
|
||||||
let sample_rate = s.get::<i32>("rate").unwrap();
|
let sample_rate = s.get::<i32>("rate").unwrap();
|
||||||
|
|
||||||
let transcription_settings = TranscriptionSettings::from(&settings, sample_rate);
|
transcription_settings = TranscriptionSettings::from(&settings, sample_rate);
|
||||||
|
|
||||||
let client_stage = if let Some(client) = state.client.take() {
|
client_stage = if let Some(client) = state.client.take() {
|
||||||
ClientStage::Ready(client)
|
ClientStage::Ready(client)
|
||||||
} else {
|
} else {
|
||||||
ClientStage::NotReady {
|
ClientStage::NotReady {
|
||||||
|
@ -695,8 +664,6 @@ impl Transcriber {
|
||||||
session_token: settings.session_token.to_owned(),
|
session_token: settings.session_token.to_owned(),
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
(client_stage, transcription_settings, transcript_tx)
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let client = match client_stage {
|
let client = match client_stage {
|
||||||
|
@ -745,8 +712,9 @@ impl Transcriber {
|
||||||
let ws_loop_handle = RUNTIME.spawn(self.build_ws_loop_fut(
|
let ws_loop_handle = RUNTIME.spawn(self.build_ws_loop_fut(
|
||||||
client,
|
client,
|
||||||
transcription_settings,
|
transcription_settings,
|
||||||
|
lateness,
|
||||||
buffer_rx,
|
buffer_rx,
|
||||||
transcript_tx,
|
transcript_notif_tx,
|
||||||
));
|
));
|
||||||
|
|
||||||
state.ws_loop_handle = Some(ws_loop_handle);
|
state.ws_loop_handle = Some(ws_loop_handle);
|
||||||
|
@ -759,18 +727,19 @@ impl Transcriber {
|
||||||
&self,
|
&self,
|
||||||
client: aws_transcribe::Client,
|
client: aws_transcribe::Client,
|
||||||
settings: TranscriptionSettings,
|
settings: TranscriptionSettings,
|
||||||
|
lateness: gst::ClockTime,
|
||||||
buffer_rx: mpsc::Receiver<gst::Buffer>,
|
buffer_rx: mpsc::Receiver<gst::Buffer>,
|
||||||
transcript_tx: mpsc::Sender<model::TranscriptEvent>,
|
transcript_notif_tx: mpsc::Sender<()>,
|
||||||
) -> impl Future<Output = Result<(), gst::ErrorMessage>> {
|
) -> impl Future<Output = Result<(), gst::ErrorMessage>> {
|
||||||
let imp_weak = self.downgrade();
|
let imp_weak = self.downgrade();
|
||||||
async move {
|
async move {
|
||||||
use gst::glib::subclass::ObjectImplWeakRef;
|
use gst::glib::subclass::ObjectImplWeakRef;
|
||||||
|
|
||||||
// Guard that restores client & transcript_tx when the ws loop is done
|
// Guard that restores client & transcript_notif_tx when the ws loop is done
|
||||||
struct Guard {
|
struct Guard {
|
||||||
imp_weak: ObjectImplWeakRef<Transcriber>,
|
imp_weak: ObjectImplWeakRef<Transcriber>,
|
||||||
client: Option<aws_transcribe::Client>,
|
client: Option<aws_transcribe::Client>,
|
||||||
transcript_tx: Option<mpsc::Sender<model::TranscriptEvent>>,
|
transcript_notif_tx: Option<mpsc::Sender<()>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Guard {
|
impl Guard {
|
||||||
|
@ -778,8 +747,8 @@ impl Transcriber {
|
||||||
self.client.as_ref().unwrap()
|
self.client.as_ref().unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn transcript_tx(&mut self) -> &mut mpsc::Sender<model::TranscriptEvent> {
|
fn transcript_notif_tx(&mut self) -> &mut mpsc::Sender<()> {
|
||||||
self.transcript_tx.as_mut().unwrap()
|
self.transcript_notif_tx.as_mut().unwrap()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -788,7 +757,7 @@ impl Transcriber {
|
||||||
if let Some(imp) = self.imp_weak.upgrade() {
|
if let Some(imp) = self.imp_weak.upgrade() {
|
||||||
let mut state = imp.state.lock().unwrap();
|
let mut state = imp.state.lock().unwrap();
|
||||||
state.client = self.client.take();
|
state.client = self.client.take();
|
||||||
state.transcript_tx = self.transcript_tx.take();
|
state.transcript_notif_tx = self.transcript_notif_tx.take();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -796,7 +765,7 @@ impl Transcriber {
|
||||||
let mut guard = Guard {
|
let mut guard = Guard {
|
||||||
imp_weak: imp_weak.clone(),
|
imp_weak: imp_weak.clone(),
|
||||||
client: Some(client),
|
client: Some(client),
|
||||||
transcript_tx: Some(transcript_tx),
|
transcript_notif_tx: Some(transcript_notif_tx),
|
||||||
};
|
};
|
||||||
|
|
||||||
// Stream the incoming buffers chunked
|
// Stream the incoming buffers chunked
|
||||||
|
@ -852,9 +821,32 @@ impl Transcriber {
|
||||||
})?
|
})?
|
||||||
{
|
{
|
||||||
if let model::TranscriptResultStream::TranscriptEvent(transcript_evt) = event {
|
if let model::TranscriptResultStream::TranscriptEvent(transcript_evt) = event {
|
||||||
if guard.transcript_tx().send(transcript_evt).await.is_err() {
|
let mut enqueued = false;
|
||||||
|
|
||||||
|
if let Some(result) = transcript_evt
|
||||||
|
.transcript
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|transcript| transcript.results())
|
||||||
|
.and_then(|results| results.get(0))
|
||||||
|
{
|
||||||
|
let Some(imp) = imp_weak.upgrade() else { break };
|
||||||
|
|
||||||
|
gst::trace!(CAT, imp: imp, "Received: {result:?}");
|
||||||
|
|
||||||
|
if let Some(alternative) = result
|
||||||
|
.alternatives
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|alternatives| alternatives.get(0))
|
||||||
|
{
|
||||||
|
if let Some(items) = alternative.items() {
|
||||||
|
enqueued = imp.enqueue(items, result.is_partial, lateness);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if enqueued && guard.transcript_notif_tx().send(()).await.is_err() {
|
||||||
if let Some(imp) = imp_weak.upgrade() {
|
if let Some(imp) = imp_weak.upgrade() {
|
||||||
gst::debug!(CAT, imp: imp, "Terminated transcript_evt channel");
|
gst::debug!(CAT, imp: imp, "Terminated transcript_notif_tx channel");
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -882,6 +874,7 @@ impl Transcriber {
|
||||||
let mut state = self.state.lock().unwrap();
|
let mut state = self.state.lock().unwrap();
|
||||||
gst::info!(CAT, imp: self, "Unpreparing");
|
gst::info!(CAT, imp: self, "Unpreparing");
|
||||||
self.stop_task();
|
self.stop_task();
|
||||||
|
// Also resets discont to true
|
||||||
*state = State::default();
|
*state = State::default();
|
||||||
gst::info!(CAT, imp: self, "Unprepared");
|
gst::info!(CAT, imp: self, "Unprepared");
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue