net/aws/transcriber: track discont offset in input stream

and add it up to subsequent transcripts.

This ensures synchronization is maintained even after the input stream
experiences a discontinuity and a gap in its timestamps.

Part-of: <https://gitlab.freedesktop.org/gstreamer/gst-plugins-rs/-/merge_requests/1230>
This commit is contained in:
Mathieu Duponchelle 2023-06-02 01:02:38 +02:00 committed by GStreamer Marge Bot
parent 2e83107c18
commit 6346d5608e
2 changed files with 71 additions and 18 deletions

View file

@ -143,12 +143,14 @@ impl From<TranslatedItem> for OutputItem {
}
struct State {
buffer_tx: Option<mpsc::Sender<gst::Buffer>>,
// second tuple member is running time
buffer_tx: Option<mpsc::Sender<(gst::Buffer, gst::ClockTime)>>,
transcriber_loop_handle: Option<task::JoinHandle<Result<(), gst::ErrorMessage>>>,
srcpads: BTreeSet<super::TranslateSrcPad>,
pad_serial: u32,
seqnum: gst::Seqnum,
start_time: Option<gst::ClockTime>,
in_segment: gst::FormattedSegment<gst::ClockTime>,
}
impl Default for State {
@ -160,6 +162,7 @@ impl Default for State {
pad_serial: 0,
seqnum: gst::Seqnum::next(),
start_time: None,
in_segment: gst::FormattedSegment::new(),
}
}
}
@ -251,17 +254,21 @@ impl Transcriber {
}
}
Segment(e) => {
let format = e.segment().format();
if format != gst::Format::Time {
let segment = match e.segment().clone().downcast::<gst::ClockTime>() {
Err(segment) => {
gst::element_imp_error!(
self,
gst::StreamError::Format,
["Only Time segments supported, got {format:?}"]
["Only Time segments supported, got {:?}", segment.format(),]
);
return false;
}
Ok(segment) => segment,
};
self.state.lock().unwrap().seqnum = e.seqnum();
let mut state = self.state.lock().unwrap();
state.seqnum = e.seqnum();
state.in_segment = segment;
true
}
@ -297,12 +304,26 @@ impl Transcriber {
gst::FlowError::Error
})?;
let rtime = match self
.state
.lock()
.unwrap()
.in_segment
.to_running_time(buffer.pts())
{
Some(rtime) => rtime,
None => {
gst::debug!(CAT, "Buffer outside segment, clipping (buffer:?)");
return Ok(gst::FlowSuccess::Ok);
}
};
let Some(mut buffer_tx) = self.state.lock().unwrap().buffer_tx.take() else {
gst::log!(CAT, obj: pad, "Flushing");
return Err(gst::FlowError::Flushing);
};
futures::executor::block_on(buffer_tx.send(buffer)).map_err(|err| {
futures::executor::block_on(buffer_tx.send((buffer, rtime))).map_err(|err| {
gst::element_imp_error!(self, gst::StreamError::Failed, ["Streaming failed: {err}"]);
gst::FlowError::Error
})?;

View file

@ -16,7 +16,7 @@ use aws_sdk_transcribestreaming::types;
use futures::channel::mpsc;
use futures::prelude::*;
use std::sync::Arc;
use std::sync::{Arc, Mutex};
use super::imp::{Settings, Transcriber};
use super::CAT;
@ -55,11 +55,19 @@ pub struct TranscriptItem {
}
impl TranscriptItem {
pub fn from(item: types::Item, lateness: gst::ClockTime) -> Option<TranscriptItem> {
pub fn from(
item: types::Item,
lateness: gst::ClockTime,
discont_offset: gst::ClockTime,
) -> Option<TranscriptItem> {
let content = item.content?;
let start_time = ((item.start_time * 1_000_000_000.0) as u64).nseconds() + lateness;
let end_time = ((item.end_time * 1_000_000_000.0) as u64).nseconds() + lateness;
let start_time =
((item.start_time * 1_000_000_000.0) as u64).nseconds() + lateness + discont_offset;
let end_time =
((item.end_time * 1_000_000_000.0) as u64).nseconds() + lateness + discont_offset;
gst::error!(CAT, "Discont offset is {discont_offset}");
Some(TranscriptItem {
pts: start_time,
@ -82,11 +90,17 @@ impl From<Vec<TranscriptItem>> for TranscriptEvent {
}
}
struct DiscontOffsetTracker {
discont_offset: gst::ClockTime,
last_chained_buffer_rtime: Option<gst::ClockTime>,
}
pub struct TranscriberStream {
imp: glib::subclass::ObjectImplRef<Transcriber>,
output: aws_transcribe::operation::start_stream_transcription::StartStreamTranscriptionOutput,
lateness: gst::ClockTime,
partial_index: usize,
discont_offset_tracker: Arc<Mutex<DiscontOffsetTracker>>,
}
impl TranscriberStream {
@ -94,7 +108,7 @@ impl TranscriberStream {
imp: &Transcriber,
settings: TranscriberSettings,
lateness: gst::ClockTime,
buffer_rx: mpsc::Receiver<gst::Buffer>,
buffer_rx: mpsc::Receiver<(gst::Buffer, gst::ClockTime)>,
) -> Result<Self, gst::ErrorMessage> {
let client = {
let aws_config = imp.aws_config.lock().unwrap();
@ -104,8 +118,23 @@ impl TranscriberStream {
aws_transcribe::Client::new(aws_config)
};
let discont_offset_tracker = Arc::new(Mutex::new(DiscontOffsetTracker {
discont_offset: gst::ClockTime::ZERO,
last_chained_buffer_rtime: gst::ClockTime::NONE,
}));
let discont_offset_tracker_clone = discont_offset_tracker.clone();
// Stream the incoming buffers chunked
let chunk_stream = buffer_rx.flat_map(move |buffer: gst::Buffer| {
let chunk_stream = buffer_rx.flat_map(move |(buffer, running_time)| {
let mut discont_offset_tracker = discont_offset_tracker_clone.lock().unwrap();
if buffer.flags().contains(gst::BufferFlags::DISCONT) {
if let Some(last_chained_buffer_rtime) = discont_offset_tracker.last_chained_buffer_rtime {
discont_offset_tracker.discont_offset += running_time.saturating_sub(last_chained_buffer_rtime);
}
}
discont_offset_tracker.last_chained_buffer_rtime = Some(running_time);
async_stream::stream! {
let data = buffer.map_readable().unwrap();
use aws_transcribe::{types::{AudioEvent, AudioStream}, primitives::Blob};
@ -146,6 +175,7 @@ impl TranscriberStream {
output,
lateness,
partial_index: 0,
discont_offset_tracker,
})
}
@ -229,7 +259,9 @@ impl TranscriberStream {
break;
}
let Some(item) = TranscriptItem::from(item, self.lateness) else { continue };
let discont_offset = self.discont_offset_tracker.lock().unwrap().discont_offset;
let Some(item) = TranscriptItem::from(item, self.lateness, discont_offset) else { continue };
gst::debug!(
CAT,
imp: self.imp,