mirror of
https://gitlab.freedesktop.org/gstreamer/gst-plugins-rs.git
synced 2024-09-27 06:20:03 +00:00
6346d5608e
and add it up to subsequent transcripts. This ensures synchronization is maintained even after the input stream experiences a discontinuity and a gap in its timestamps. Part-of: <https://gitlab.freedesktop.org/gstreamer/gst-plugins-rs/-/merge_requests/1230>
287 lines
9.5 KiB
Rust
287 lines
9.5 KiB
Rust
// Copyright (C) 2020 Mathieu Duponchelle <mathieu@centricular.com>
|
|
// Copyright (C) 2023 François Laignel <francois@centricular.com>
|
|
//
|
|
// This Source Code Form is subject to the terms of the Mozilla Public License, v2.0.
|
|
// If a copy of the MPL was not distributed with this file, You can obtain one at
|
|
// <https://mozilla.org/MPL/2.0/>.
|
|
//
|
|
// SPDX-License-Identifier: MPL-2.0
|
|
|
|
use gst::subclass::prelude::*;
|
|
use gst::{glib, prelude::*};
|
|
|
|
use aws_sdk_transcribestreaming as aws_transcribe;
|
|
use aws_sdk_transcribestreaming::types;
|
|
|
|
use futures::channel::mpsc;
|
|
use futures::prelude::*;
|
|
|
|
use std::sync::{Arc, Mutex};
|
|
|
|
use super::imp::{Settings, Transcriber};
|
|
use super::CAT;
|
|
|
|
#[derive(Debug)]
|
|
pub struct TranscriberSettings {
|
|
lang_code: types::LanguageCode,
|
|
sample_rate: i32,
|
|
vocabulary: Option<String>,
|
|
vocabulary_filter: Option<String>,
|
|
vocabulary_filter_method: types::VocabularyFilterMethod,
|
|
session_id: Option<String>,
|
|
results_stability: types::PartialResultsStability,
|
|
}
|
|
|
|
impl TranscriberSettings {
|
|
pub(super) fn from(settings: &Settings, sample_rate: i32) -> Self {
|
|
TranscriberSettings {
|
|
lang_code: settings.language_code.as_str().into(),
|
|
sample_rate,
|
|
vocabulary: settings.vocabulary.clone(),
|
|
vocabulary_filter: settings.vocabulary_filter.clone(),
|
|
vocabulary_filter_method: settings.vocabulary_filter_method.into(),
|
|
session_id: settings.session_id.clone(),
|
|
results_stability: settings.results_stability.into(),
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Clone, Debug, Default)]
|
|
pub struct TranscriptItem {
|
|
pub pts: gst::ClockTime,
|
|
pub duration: gst::ClockTime,
|
|
pub content: String,
|
|
pub is_punctuation: bool,
|
|
}
|
|
|
|
impl TranscriptItem {
|
|
pub fn from(
|
|
item: types::Item,
|
|
lateness: gst::ClockTime,
|
|
discont_offset: gst::ClockTime,
|
|
) -> Option<TranscriptItem> {
|
|
let content = item.content?;
|
|
|
|
let start_time =
|
|
((item.start_time * 1_000_000_000.0) as u64).nseconds() + lateness + discont_offset;
|
|
let end_time =
|
|
((item.end_time * 1_000_000_000.0) as u64).nseconds() + lateness + discont_offset;
|
|
|
|
gst::error!(CAT, "Discont offset is {discont_offset}");
|
|
|
|
Some(TranscriptItem {
|
|
pts: start_time,
|
|
duration: end_time - start_time,
|
|
content,
|
|
is_punctuation: matches!(item.r#type, Some(types::ItemType::Punctuation)),
|
|
})
|
|
}
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
pub enum TranscriptEvent {
|
|
Items(Arc<Vec<TranscriptItem>>),
|
|
Eos,
|
|
}
|
|
|
|
impl From<Vec<TranscriptItem>> for TranscriptEvent {
|
|
fn from(transcript_items: Vec<TranscriptItem>) -> Self {
|
|
TranscriptEvent::Items(transcript_items.into())
|
|
}
|
|
}
|
|
|
|
struct DiscontOffsetTracker {
|
|
discont_offset: gst::ClockTime,
|
|
last_chained_buffer_rtime: Option<gst::ClockTime>,
|
|
}
|
|
|
|
pub struct TranscriberStream {
|
|
imp: glib::subclass::ObjectImplRef<Transcriber>,
|
|
output: aws_transcribe::operation::start_stream_transcription::StartStreamTranscriptionOutput,
|
|
lateness: gst::ClockTime,
|
|
partial_index: usize,
|
|
discont_offset_tracker: Arc<Mutex<DiscontOffsetTracker>>,
|
|
}
|
|
|
|
impl TranscriberStream {
|
|
pub async fn try_new(
|
|
imp: &Transcriber,
|
|
settings: TranscriberSettings,
|
|
lateness: gst::ClockTime,
|
|
buffer_rx: mpsc::Receiver<(gst::Buffer, gst::ClockTime)>,
|
|
) -> Result<Self, gst::ErrorMessage> {
|
|
let client = {
|
|
let aws_config = imp.aws_config.lock().unwrap();
|
|
let aws_config = aws_config
|
|
.as_ref()
|
|
.expect("aws_config must be initialized at this stage");
|
|
aws_transcribe::Client::new(aws_config)
|
|
};
|
|
|
|
let discont_offset_tracker = Arc::new(Mutex::new(DiscontOffsetTracker {
|
|
discont_offset: gst::ClockTime::ZERO,
|
|
last_chained_buffer_rtime: gst::ClockTime::NONE,
|
|
}));
|
|
|
|
let discont_offset_tracker_clone = discont_offset_tracker.clone();
|
|
|
|
// Stream the incoming buffers chunked
|
|
let chunk_stream = buffer_rx.flat_map(move |(buffer, running_time)| {
|
|
let mut discont_offset_tracker = discont_offset_tracker_clone.lock().unwrap();
|
|
if buffer.flags().contains(gst::BufferFlags::DISCONT) {
|
|
if let Some(last_chained_buffer_rtime) = discont_offset_tracker.last_chained_buffer_rtime {
|
|
discont_offset_tracker.discont_offset += running_time.saturating_sub(last_chained_buffer_rtime);
|
|
}
|
|
}
|
|
|
|
discont_offset_tracker.last_chained_buffer_rtime = Some(running_time);
|
|
async_stream::stream! {
|
|
let data = buffer.map_readable().unwrap();
|
|
use aws_transcribe::{types::{AudioEvent, AudioStream}, primitives::Blob};
|
|
for chunk in data.chunks(8192) {
|
|
yield Ok(AudioStream::AudioEvent(AudioEvent::builder().audio_chunk(Blob::new(chunk)).build()));
|
|
}
|
|
}
|
|
});
|
|
|
|
let mut transcribe_builder = client
|
|
.start_stream_transcription()
|
|
.language_code(settings.lang_code)
|
|
.media_sample_rate_hertz(settings.sample_rate)
|
|
.media_encoding(types::MediaEncoding::Pcm)
|
|
.enable_partial_results_stabilization(true)
|
|
.partial_results_stability(settings.results_stability)
|
|
.set_vocabulary_name(settings.vocabulary)
|
|
.set_session_id(settings.session_id);
|
|
|
|
if let Some(vocabulary_filter) = settings.vocabulary_filter {
|
|
transcribe_builder = transcribe_builder
|
|
.vocabulary_filter_name(vocabulary_filter)
|
|
.vocabulary_filter_method(settings.vocabulary_filter_method);
|
|
}
|
|
|
|
let output = transcribe_builder
|
|
.audio_stream(chunk_stream.into())
|
|
.send()
|
|
.await
|
|
.map_err(|err| {
|
|
let err = format!("Transcribe ws init error: {err}");
|
|
gst::error!(CAT, imp: imp, "{err}");
|
|
gst::error_msg!(gst::LibraryError::Init, ["{err}"])
|
|
})?;
|
|
|
|
Ok(TranscriberStream {
|
|
imp: imp.ref_counted(),
|
|
output,
|
|
lateness,
|
|
partial_index: 0,
|
|
discont_offset_tracker,
|
|
})
|
|
}
|
|
|
|
pub async fn next(&mut self) -> Result<TranscriptEvent, gst::ErrorMessage> {
|
|
loop {
|
|
let event = self
|
|
.output
|
|
.transcript_result_stream
|
|
.recv()
|
|
.await
|
|
.map_err(|err| {
|
|
let err = format!("Transcribe ws stream error: {err}");
|
|
gst::error!(CAT, imp: self.imp, "{err}");
|
|
gst::error_msg!(gst::LibraryError::Failed, ["{err}"])
|
|
})?;
|
|
|
|
let Some(event) = event else {
|
|
gst::debug!(CAT, imp: self.imp, "Transcriber loop sending EOS");
|
|
return Ok(TranscriptEvent::Eos);
|
|
};
|
|
|
|
if let types::TranscriptResultStream::TranscriptEvent(transcript_evt) = event {
|
|
let mut ready_items = None;
|
|
|
|
if let Some(result) = transcript_evt
|
|
.transcript
|
|
.and_then(|transcript| transcript.results)
|
|
.and_then(|mut results| results.drain(..).next())
|
|
{
|
|
gst::trace!(CAT, imp: self.imp, "Received: {result:?}");
|
|
|
|
if let Some(alternative) = result
|
|
.alternatives
|
|
.and_then(|mut alternatives| alternatives.drain(..).next())
|
|
{
|
|
ready_items = alternative.items.and_then(|items| {
|
|
self.get_ready_transcript_items(items, result.is_partial)
|
|
});
|
|
}
|
|
}
|
|
|
|
if let Some(ready_items) = ready_items {
|
|
return Ok(ready_items.into());
|
|
}
|
|
} else {
|
|
gst::warning!(
|
|
CAT,
|
|
imp: self.imp,
|
|
"Transcribe ws returned unknown event: consider upgrading the SDK"
|
|
)
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Builds a list from the provided stable items.
|
|
fn get_ready_transcript_items(
|
|
&mut self,
|
|
mut items: Vec<types::Item>,
|
|
partial: bool,
|
|
) -> Option<Vec<TranscriptItem>> {
|
|
if items.len() <= self.partial_index {
|
|
gst::error!(
|
|
CAT,
|
|
imp: self.imp,
|
|
"sanity check failed, alternative length {} < partial_index {}",
|
|
items.len(),
|
|
self.partial_index
|
|
);
|
|
|
|
if !partial {
|
|
self.partial_index = 0;
|
|
}
|
|
|
|
return None;
|
|
}
|
|
|
|
let mut output = vec![];
|
|
|
|
for item in items.drain(self.partial_index..) {
|
|
if !item.stable().unwrap_or(false) {
|
|
break;
|
|
}
|
|
|
|
let discont_offset = self.discont_offset_tracker.lock().unwrap().discont_offset;
|
|
|
|
let Some(item) = TranscriptItem::from(item, self.lateness, discont_offset) else { continue };
|
|
gst::debug!(
|
|
CAT,
|
|
imp: self.imp,
|
|
"Item is ready for queuing: {}, PTS {}",
|
|
item.content,
|
|
item.pts,
|
|
);
|
|
|
|
self.partial_index += 1;
|
|
output.push(item);
|
|
}
|
|
|
|
if !partial {
|
|
self.partial_index = 0;
|
|
}
|
|
|
|
if output.is_empty() {
|
|
return None;
|
|
}
|
|
|
|
Some(output)
|
|
}
|
|
}
|