mirror of
https://gitlab.freedesktop.org/gstreamer/gst-plugins-rs.git
synced 2024-11-30 07:20:59 +00:00
net/aws/transcriber: track discont offset in input stream
Track the offset introduced by discontinuities in the input stream and add it to the timestamps of subsequent transcripts. This ensures synchronization is maintained even after the input stream experiences a discontinuity and a gap in its timestamps.
Part-of: <https://gitlab.freedesktop.org/gstreamer/gst-plugins-rs/-/merge_requests/1230>
This commit is contained in:
parent
2e83107c18
commit
6346d5608e
2 changed files with 71 additions and 18 deletions
|
@ -143,12 +143,14 @@ impl From<TranslatedItem> for OutputItem {
|
||||||
}
|
}
|
||||||
|
|
||||||
struct State {
|
struct State {
|
||||||
buffer_tx: Option<mpsc::Sender<gst::Buffer>>,
|
// second tuple member is running time
|
||||||
|
buffer_tx: Option<mpsc::Sender<(gst::Buffer, gst::ClockTime)>>,
|
||||||
transcriber_loop_handle: Option<task::JoinHandle<Result<(), gst::ErrorMessage>>>,
|
transcriber_loop_handle: Option<task::JoinHandle<Result<(), gst::ErrorMessage>>>,
|
||||||
srcpads: BTreeSet<super::TranslateSrcPad>,
|
srcpads: BTreeSet<super::TranslateSrcPad>,
|
||||||
pad_serial: u32,
|
pad_serial: u32,
|
||||||
seqnum: gst::Seqnum,
|
seqnum: gst::Seqnum,
|
||||||
start_time: Option<gst::ClockTime>,
|
start_time: Option<gst::ClockTime>,
|
||||||
|
in_segment: gst::FormattedSegment<gst::ClockTime>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for State {
|
impl Default for State {
|
||||||
|
@ -160,6 +162,7 @@ impl Default for State {
|
||||||
pad_serial: 0,
|
pad_serial: 0,
|
||||||
seqnum: gst::Seqnum::next(),
|
seqnum: gst::Seqnum::next(),
|
||||||
start_time: None,
|
start_time: None,
|
||||||
|
in_segment: gst::FormattedSegment::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -251,17 +254,21 @@ impl Transcriber {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Segment(e) => {
|
Segment(e) => {
|
||||||
let format = e.segment().format();
|
let segment = match e.segment().clone().downcast::<gst::ClockTime>() {
|
||||||
if format != gst::Format::Time {
|
Err(segment) => {
|
||||||
gst::element_imp_error!(
|
gst::element_imp_error!(
|
||||||
self,
|
self,
|
||||||
gst::StreamError::Format,
|
gst::StreamError::Format,
|
||||||
["Only Time segments supported, got {format:?}"]
|
["Only Time segments supported, got {:?}", segment.format(),]
|
||||||
);
|
);
|
||||||
return false;
|
return false;
|
||||||
|
}
|
||||||
|
Ok(segment) => segment,
|
||||||
};
|
};
|
||||||
|
|
||||||
self.state.lock().unwrap().seqnum = e.seqnum();
|
let mut state = self.state.lock().unwrap();
|
||||||
|
state.seqnum = e.seqnum();
|
||||||
|
state.in_segment = segment;
|
||||||
|
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
@ -297,12 +304,26 @@ impl Transcriber {
|
||||||
gst::FlowError::Error
|
gst::FlowError::Error
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
|
let rtime = match self
|
||||||
|
.state
|
||||||
|
.lock()
|
||||||
|
.unwrap()
|
||||||
|
.in_segment
|
||||||
|
.to_running_time(buffer.pts())
|
||||||
|
{
|
||||||
|
Some(rtime) => rtime,
|
||||||
|
None => {
|
||||||
|
gst::debug!(CAT, "Buffer outside segment, clipping (buffer:?)");
|
||||||
|
return Ok(gst::FlowSuccess::Ok);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
let Some(mut buffer_tx) = self.state.lock().unwrap().buffer_tx.take() else {
|
let Some(mut buffer_tx) = self.state.lock().unwrap().buffer_tx.take() else {
|
||||||
gst::log!(CAT, obj: pad, "Flushing");
|
gst::log!(CAT, obj: pad, "Flushing");
|
||||||
return Err(gst::FlowError::Flushing);
|
return Err(gst::FlowError::Flushing);
|
||||||
};
|
};
|
||||||
|
|
||||||
futures::executor::block_on(buffer_tx.send(buffer)).map_err(|err| {
|
futures::executor::block_on(buffer_tx.send((buffer, rtime))).map_err(|err| {
|
||||||
gst::element_imp_error!(self, gst::StreamError::Failed, ["Streaming failed: {err}"]);
|
gst::element_imp_error!(self, gst::StreamError::Failed, ["Streaming failed: {err}"]);
|
||||||
gst::FlowError::Error
|
gst::FlowError::Error
|
||||||
})?;
|
})?;
|
||||||
|
|
|
@ -16,7 +16,7 @@ use aws_sdk_transcribestreaming::types;
|
||||||
use futures::channel::mpsc;
|
use futures::channel::mpsc;
|
||||||
use futures::prelude::*;
|
use futures::prelude::*;
|
||||||
|
|
||||||
use std::sync::Arc;
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
use super::imp::{Settings, Transcriber};
|
use super::imp::{Settings, Transcriber};
|
||||||
use super::CAT;
|
use super::CAT;
|
||||||
|
@ -55,11 +55,19 @@ pub struct TranscriptItem {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TranscriptItem {
|
impl TranscriptItem {
|
||||||
pub fn from(item: types::Item, lateness: gst::ClockTime) -> Option<TranscriptItem> {
|
pub fn from(
|
||||||
|
item: types::Item,
|
||||||
|
lateness: gst::ClockTime,
|
||||||
|
discont_offset: gst::ClockTime,
|
||||||
|
) -> Option<TranscriptItem> {
|
||||||
let content = item.content?;
|
let content = item.content?;
|
||||||
|
|
||||||
let start_time = ((item.start_time * 1_000_000_000.0) as u64).nseconds() + lateness;
|
let start_time =
|
||||||
let end_time = ((item.end_time * 1_000_000_000.0) as u64).nseconds() + lateness;
|
((item.start_time * 1_000_000_000.0) as u64).nseconds() + lateness + discont_offset;
|
||||||
|
let end_time =
|
||||||
|
((item.end_time * 1_000_000_000.0) as u64).nseconds() + lateness + discont_offset;
|
||||||
|
|
||||||
|
gst::error!(CAT, "Discont offset is {discont_offset}");
|
||||||
|
|
||||||
Some(TranscriptItem {
|
Some(TranscriptItem {
|
||||||
pts: start_time,
|
pts: start_time,
|
||||||
|
@ -82,11 +90,17 @@ impl From<Vec<TranscriptItem>> for TranscriptEvent {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Bookkeeping for timestamp jumps caused by discontinuities in the input stream.
///
/// An instance is shared behind an `Arc<Mutex<_>>`: the closure that chunks incoming
/// buffers updates it for every buffer, and `discont_offset` is read back when
/// transcript items are built.
struct DiscontOffsetTracker {
|
||||||
|
/// Accumulated offset. For each buffer flagged `DISCONT`, the (saturating) difference
/// between its running time and the previous buffer's running time is added here.
/// The total is added to the start and end times of transcript items.
discont_offset: gst::ClockTime,
|
||||||
|
/// Running time of the most recently received buffer; `None` until the first
/// buffer has been seen.
last_chained_buffer_rtime: Option<gst::ClockTime>,
|
||||||
|
}
|
||||||
|
|
||||||
pub struct TranscriberStream {
|
pub struct TranscriberStream {
|
||||||
imp: glib::subclass::ObjectImplRef<Transcriber>,
|
imp: glib::subclass::ObjectImplRef<Transcriber>,
|
||||||
output: aws_transcribe::operation::start_stream_transcription::StartStreamTranscriptionOutput,
|
output: aws_transcribe::operation::start_stream_transcription::StartStreamTranscriptionOutput,
|
||||||
lateness: gst::ClockTime,
|
lateness: gst::ClockTime,
|
||||||
partial_index: usize,
|
partial_index: usize,
|
||||||
|
discont_offset_tracker: Arc<Mutex<DiscontOffsetTracker>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TranscriberStream {
|
impl TranscriberStream {
|
||||||
|
@ -94,7 +108,7 @@ impl TranscriberStream {
|
||||||
imp: &Transcriber,
|
imp: &Transcriber,
|
||||||
settings: TranscriberSettings,
|
settings: TranscriberSettings,
|
||||||
lateness: gst::ClockTime,
|
lateness: gst::ClockTime,
|
||||||
buffer_rx: mpsc::Receiver<gst::Buffer>,
|
buffer_rx: mpsc::Receiver<(gst::Buffer, gst::ClockTime)>,
|
||||||
) -> Result<Self, gst::ErrorMessage> {
|
) -> Result<Self, gst::ErrorMessage> {
|
||||||
let client = {
|
let client = {
|
||||||
let aws_config = imp.aws_config.lock().unwrap();
|
let aws_config = imp.aws_config.lock().unwrap();
|
||||||
|
@ -104,8 +118,23 @@ impl TranscriberStream {
|
||||||
aws_transcribe::Client::new(aws_config)
|
aws_transcribe::Client::new(aws_config)
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let discont_offset_tracker = Arc::new(Mutex::new(DiscontOffsetTracker {
|
||||||
|
discont_offset: gst::ClockTime::ZERO,
|
||||||
|
last_chained_buffer_rtime: gst::ClockTime::NONE,
|
||||||
|
}));
|
||||||
|
|
||||||
|
let discont_offset_tracker_clone = discont_offset_tracker.clone();
|
||||||
|
|
||||||
// Stream the incoming buffers chunked
|
// Stream the incoming buffers chunked
|
||||||
let chunk_stream = buffer_rx.flat_map(move |buffer: gst::Buffer| {
|
let chunk_stream = buffer_rx.flat_map(move |(buffer, running_time)| {
|
||||||
|
let mut discont_offset_tracker = discont_offset_tracker_clone.lock().unwrap();
|
||||||
|
if buffer.flags().contains(gst::BufferFlags::DISCONT) {
|
||||||
|
if let Some(last_chained_buffer_rtime) = discont_offset_tracker.last_chained_buffer_rtime {
|
||||||
|
discont_offset_tracker.discont_offset += running_time.saturating_sub(last_chained_buffer_rtime);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
discont_offset_tracker.last_chained_buffer_rtime = Some(running_time);
|
||||||
async_stream::stream! {
|
async_stream::stream! {
|
||||||
let data = buffer.map_readable().unwrap();
|
let data = buffer.map_readable().unwrap();
|
||||||
use aws_transcribe::{types::{AudioEvent, AudioStream}, primitives::Blob};
|
use aws_transcribe::{types::{AudioEvent, AudioStream}, primitives::Blob};
|
||||||
|
@ -146,6 +175,7 @@ impl TranscriberStream {
|
||||||
output,
|
output,
|
||||||
lateness,
|
lateness,
|
||||||
partial_index: 0,
|
partial_index: 0,
|
||||||
|
discont_offset_tracker,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -229,7 +259,9 @@ impl TranscriberStream {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
let Some(item) = TranscriptItem::from(item, self.lateness) else { continue };
|
let discont_offset = self.discont_offset_tracker.lock().unwrap().discont_offset;
|
||||||
|
|
||||||
|
let Some(item) = TranscriptItem::from(item, self.lateness, discont_offset) else { continue };
|
||||||
gst::debug!(
|
gst::debug!(
|
||||||
CAT,
|
CAT,
|
||||||
imp: self.imp,
|
imp: self.imp,
|
||||||
|
|
Loading…
Reference in a new issue