mirror of
https://gitlab.freedesktop.org/gstreamer/gst-plugins-rs.git
synced 2024-11-26 13:31:00 +00:00
net/aws/transcriber: track discont offset in input stream
and add it up to subsequent transcripts. This ensures synchronization is maintained even after the input stream experiences a discontinuity and a gap in its timestamps. Part-of: <https://gitlab.freedesktop.org/gstreamer/gst-plugins-rs/-/merge_requests/1230>
This commit is contained in:
parent
2e83107c18
commit
6346d5608e
2 changed files with 71 additions and 18 deletions
|
@ -143,12 +143,14 @@ impl From<TranslatedItem> for OutputItem {
|
|||
}
|
||||
|
||||
struct State {
|
||||
buffer_tx: Option<mpsc::Sender<gst::Buffer>>,
|
||||
// second tuple member is running time
|
||||
buffer_tx: Option<mpsc::Sender<(gst::Buffer, gst::ClockTime)>>,
|
||||
transcriber_loop_handle: Option<task::JoinHandle<Result<(), gst::ErrorMessage>>>,
|
||||
srcpads: BTreeSet<super::TranslateSrcPad>,
|
||||
pad_serial: u32,
|
||||
seqnum: gst::Seqnum,
|
||||
start_time: Option<gst::ClockTime>,
|
||||
in_segment: gst::FormattedSegment<gst::ClockTime>,
|
||||
}
|
||||
|
||||
impl Default for State {
|
||||
|
@ -160,6 +162,7 @@ impl Default for State {
|
|||
pad_serial: 0,
|
||||
seqnum: gst::Seqnum::next(),
|
||||
start_time: None,
|
||||
in_segment: gst::FormattedSegment::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -251,17 +254,21 @@ impl Transcriber {
|
|||
}
|
||||
}
|
||||
Segment(e) => {
|
||||
let format = e.segment().format();
|
||||
if format != gst::Format::Time {
|
||||
gst::element_imp_error!(
|
||||
self,
|
||||
gst::StreamError::Format,
|
||||
["Only Time segments supported, got {format:?}"]
|
||||
);
|
||||
return false;
|
||||
let segment = match e.segment().clone().downcast::<gst::ClockTime>() {
|
||||
Err(segment) => {
|
||||
gst::element_imp_error!(
|
||||
self,
|
||||
gst::StreamError::Format,
|
||||
["Only Time segments supported, got {:?}", segment.format(),]
|
||||
);
|
||||
return false;
|
||||
}
|
||||
Ok(segment) => segment,
|
||||
};
|
||||
|
||||
self.state.lock().unwrap().seqnum = e.seqnum();
|
||||
let mut state = self.state.lock().unwrap();
|
||||
state.seqnum = e.seqnum();
|
||||
state.in_segment = segment;
|
||||
|
||||
true
|
||||
}
|
||||
|
@ -297,12 +304,26 @@ impl Transcriber {
|
|||
gst::FlowError::Error
|
||||
})?;
|
||||
|
||||
let rtime = match self
|
||||
.state
|
||||
.lock()
|
||||
.unwrap()
|
||||
.in_segment
|
||||
.to_running_time(buffer.pts())
|
||||
{
|
||||
Some(rtime) => rtime,
|
||||
None => {
|
||||
gst::debug!(CAT, "Buffer outside segment, clipping (buffer:?)");
|
||||
return Ok(gst::FlowSuccess::Ok);
|
||||
}
|
||||
};
|
||||
|
||||
let Some(mut buffer_tx) = self.state.lock().unwrap().buffer_tx.take() else {
|
||||
gst::log!(CAT, obj: pad, "Flushing");
|
||||
return Err(gst::FlowError::Flushing);
|
||||
};
|
||||
|
||||
futures::executor::block_on(buffer_tx.send(buffer)).map_err(|err| {
|
||||
futures::executor::block_on(buffer_tx.send((buffer, rtime))).map_err(|err| {
|
||||
gst::element_imp_error!(self, gst::StreamError::Failed, ["Streaming failed: {err}"]);
|
||||
gst::FlowError::Error
|
||||
})?;
|
||||
|
|
|
@ -16,7 +16,7 @@ use aws_sdk_transcribestreaming::types;
|
|||
use futures::channel::mpsc;
|
||||
use futures::prelude::*;
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use super::imp::{Settings, Transcriber};
|
||||
use super::CAT;
|
||||
|
@ -55,11 +55,19 @@ pub struct TranscriptItem {
|
|||
}
|
||||
|
||||
impl TranscriptItem {
|
||||
pub fn from(item: types::Item, lateness: gst::ClockTime) -> Option<TranscriptItem> {
|
||||
pub fn from(
|
||||
item: types::Item,
|
||||
lateness: gst::ClockTime,
|
||||
discont_offset: gst::ClockTime,
|
||||
) -> Option<TranscriptItem> {
|
||||
let content = item.content?;
|
||||
|
||||
let start_time = ((item.start_time * 1_000_000_000.0) as u64).nseconds() + lateness;
|
||||
let end_time = ((item.end_time * 1_000_000_000.0) as u64).nseconds() + lateness;
|
||||
let start_time =
|
||||
((item.start_time * 1_000_000_000.0) as u64).nseconds() + lateness + discont_offset;
|
||||
let end_time =
|
||||
((item.end_time * 1_000_000_000.0) as u64).nseconds() + lateness + discont_offset;
|
||||
|
||||
gst::error!(CAT, "Discont offset is {discont_offset}");
|
||||
|
||||
Some(TranscriptItem {
|
||||
pts: start_time,
|
||||
|
@ -82,11 +90,17 @@ impl From<Vec<TranscriptItem>> for TranscriptEvent {
|
|||
}
|
||||
}
|
||||
|
||||
struct DiscontOffsetTracker {
|
||||
discont_offset: gst::ClockTime,
|
||||
last_chained_buffer_rtime: Option<gst::ClockTime>,
|
||||
}
|
||||
|
||||
pub struct TranscriberStream {
|
||||
imp: glib::subclass::ObjectImplRef<Transcriber>,
|
||||
output: aws_transcribe::operation::start_stream_transcription::StartStreamTranscriptionOutput,
|
||||
lateness: gst::ClockTime,
|
||||
partial_index: usize,
|
||||
discont_offset_tracker: Arc<Mutex<DiscontOffsetTracker>>,
|
||||
}
|
||||
|
||||
impl TranscriberStream {
|
||||
|
@ -94,7 +108,7 @@ impl TranscriberStream {
|
|||
imp: &Transcriber,
|
||||
settings: TranscriberSettings,
|
||||
lateness: gst::ClockTime,
|
||||
buffer_rx: mpsc::Receiver<gst::Buffer>,
|
||||
buffer_rx: mpsc::Receiver<(gst::Buffer, gst::ClockTime)>,
|
||||
) -> Result<Self, gst::ErrorMessage> {
|
||||
let client = {
|
||||
let aws_config = imp.aws_config.lock().unwrap();
|
||||
|
@ -104,8 +118,23 @@ impl TranscriberStream {
|
|||
aws_transcribe::Client::new(aws_config)
|
||||
};
|
||||
|
||||
let discont_offset_tracker = Arc::new(Mutex::new(DiscontOffsetTracker {
|
||||
discont_offset: gst::ClockTime::ZERO,
|
||||
last_chained_buffer_rtime: gst::ClockTime::NONE,
|
||||
}));
|
||||
|
||||
let discont_offset_tracker_clone = discont_offset_tracker.clone();
|
||||
|
||||
// Stream the incoming buffers chunked
|
||||
let chunk_stream = buffer_rx.flat_map(move |buffer: gst::Buffer| {
|
||||
let chunk_stream = buffer_rx.flat_map(move |(buffer, running_time)| {
|
||||
let mut discont_offset_tracker = discont_offset_tracker_clone.lock().unwrap();
|
||||
if buffer.flags().contains(gst::BufferFlags::DISCONT) {
|
||||
if let Some(last_chained_buffer_rtime) = discont_offset_tracker.last_chained_buffer_rtime {
|
||||
discont_offset_tracker.discont_offset += running_time.saturating_sub(last_chained_buffer_rtime);
|
||||
}
|
||||
}
|
||||
|
||||
discont_offset_tracker.last_chained_buffer_rtime = Some(running_time);
|
||||
async_stream::stream! {
|
||||
let data = buffer.map_readable().unwrap();
|
||||
use aws_transcribe::{types::{AudioEvent, AudioStream}, primitives::Blob};
|
||||
|
@ -146,6 +175,7 @@ impl TranscriberStream {
|
|||
output,
|
||||
lateness,
|
||||
partial_index: 0,
|
||||
discont_offset_tracker,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -229,7 +259,9 @@ impl TranscriberStream {
|
|||
break;
|
||||
}
|
||||
|
||||
let Some(item) = TranscriptItem::from(item, self.lateness) else { continue };
|
||||
let discont_offset = self.discont_offset_tracker.lock().unwrap().discont_offset;
|
||||
|
||||
let Some(item) = TranscriptItem::from(item, self.lateness, discont_offset) else { continue };
|
||||
gst::debug!(
|
||||
CAT,
|
||||
imp: self.imp,
|
||||
|
|
Loading…
Reference in a new issue