awstranscribe: output original transcripts to separate pad

When the transcriber is used in a live situation, it can be useful
to save a transcript for editing after the fact when producing a
VOD.

Each source pad now gets an "unsynced_" pendant. That unsynced pad
is pushed to from the context of the "live" source pad task. Flow
returns from the unsynced pads are ignored, we simply check the
last flow return before attempting to push the next transcript.

Part-of: <https://gitlab.freedesktop.org/gstreamer/gst-plugins-rs/-/merge_requests/1915>
This commit is contained in:
Mathieu Duponchelle 2024-11-13 15:55:35 +01:00 committed by GStreamer Marge Bot
parent 39e9ad1d29
commit b5bd7d047c
8 changed files with 536 additions and 54 deletions

1
Cargo.lock generated
View file

@ -2458,6 +2458,7 @@ dependencies = [
"serde", "serde",
"serde_derive", "serde_derive",
"serde_json", "serde_json",
"serde_with",
"sprintf", "sprintf",
"test-with", "test-with",
"tokio", "tokio",

View file

@ -1162,6 +1162,16 @@
"direction": "src", "direction": "src",
"presence": "request", "presence": "request",
"type": "GstTranslateSrcPad" "type": "GstTranslateSrcPad"
},
"unsynced_src": {
"caps": "application/x-json:\n",
"direction": "src",
"presence": "always"
},
"unsynced_translate_src_%%u": {
"caps": "application/x-json:\n",
"direction": "src",
"presence": "request"
} }
}, },
"properties": { "properties": {

View file

@ -31,6 +31,7 @@ tokio = { version = "1.0", features = [ "full" ] }
serde = "1" serde = "1"
serde_derive = "1" serde_derive = "1"
serde_json = "1" serde_json = "1"
serde_with = "3"
url = "2" url = "2"
gst-video = { workspace = true, features = ["v1_22"] } gst-video = { workspace = true, features = ["v1_22"] }
sprintf = "0.2" sprintf = "0.2"

View file

@ -35,7 +35,7 @@ use std::sync::{Arc, Mutex};
use std::sync::LazyLock; use std::sync::LazyLock;
use super::transcribe::{TranscriberSettings, TranscriberStream, TranscriptEvent, TranscriptItem}; use super::transcribe::{TranscriberSettings, TranscriberStream, TranscriptEvent, TranscriptItem};
use super::translate::{TranslateLoop, TranslatedItem}; use super::translate::{TranslateLoop, TranslatedItem, Translation};
use super::{ use super::{
AwsTranscriberResultStability, AwsTranscriberVocabularyFilterMethod, AwsTranscriberResultStability, AwsTranscriberVocabularyFilterMethod,
TranslationTokenizationMethod, CAT, TranslationTokenizationMethod, CAT,
@ -462,9 +462,16 @@ impl Transcriber {
use TranscriptEvent::*; use TranscriptEvent::*;
match res { match res {
None => (), None => (),
Some(Items(items)) => { Some(Transcript {
items,
serialized
}) => {
if imp.transcript_event_tx.receiver_count() > 0 { if imp.transcript_event_tx.receiver_count() > 0 {
let _ = imp.transcript_event_tx.send(Items(items.clone())); let _ = imp.transcript_event_tx.send(
Transcript {
items: items.clone(),
serialized: serialized.clone()
});
} }
if imp.transcript_event_for_translate_tx.receiver_count() > 0 { if imp.transcript_event_for_translate_tx.receiver_count() > 0 {
@ -472,7 +479,10 @@ impl Transcriber {
if let Some(items_to_translate) = translate_queue.push(item) { if let Some(items_to_translate) = translate_queue.push(item) {
let _ = imp let _ = imp
.transcript_event_for_translate_tx .transcript_event_for_translate_tx
.send(Items(items_to_translate.into())); .send(Transcript {
items: items_to_translate.into(),
serialized: None,
});
} }
} }
} }
@ -489,7 +499,10 @@ impl Transcriber {
translate_queue.drain().collect(); translate_queue.drain().collect();
let _ = imp let _ = imp
.transcript_event_for_translate_tx .transcript_event_for_translate_tx
.send(Items(items_to_translate.into())); .send(Transcript {
items: items_to_translate.into(),
serialized: None,
});
let _ = imp.transcript_event_for_translate_tx.send(Eos); let _ = imp.transcript_event_for_translate_tx.send(Eos);
} }
@ -518,7 +531,10 @@ impl Transcriber {
); );
let _ = imp let _ = imp
.transcript_event_for_translate_tx .transcript_event_for_translate_tx
.send(Items(items_to_translate.into())); .send(Transcript {
items: items_to_translate.into(),
serialized: None,
});
} }
} }
} }
@ -662,12 +678,20 @@ impl ObjectSubclass for Transcriber {
.flags(gst::PadFlags::FIXED_CAPS) .flags(gst::PadFlags::FIXED_CAPS)
.build(); .build();
let templ = klass.pad_template("unsynced_src").unwrap();
let static_unsynced_srcpad = gst::PadBuilder::<gst::Pad>::from_template(&templ)
.flags(gst::PadFlags::FIXED_CAPS)
.build();
// Setting the channel capacity so that a TranslateSrcPad that would lag // Setting the channel capacity so that a TranslateSrcPad that would lag
// behind for some reasons get a chance to catch-up without loosing items. // behind for some reasons get a chance to catch-up without loosing items.
// Receiver will be created by subscribing to sender later. // Receiver will be created by subscribing to sender later.
let (transcript_event_for_translate_tx, _) = broadcast::channel(128); let (transcript_event_for_translate_tx, _) = broadcast::channel(128);
let (transcript_event_tx, _) = broadcast::channel(128); let (transcript_event_tx, _) = broadcast::channel(128);
static_srcpad
.imp()
.set_unsynced_pad(&static_unsynced_srcpad);
Self { Self {
static_srcpad, static_srcpad,
sinkpad, sinkpad,
@ -785,6 +809,17 @@ impl ObjectImpl for Transcriber {
let obj = self.obj(); let obj = self.obj();
obj.add_pad(&self.sinkpad).unwrap(); obj.add_pad(&self.sinkpad).unwrap();
obj.add_pad(&self.static_srcpad).unwrap(); obj.add_pad(&self.static_srcpad).unwrap();
obj.add_pad(
self.static_srcpad
.imp()
.state
.lock()
.unwrap()
.unsynced_pad
.as_ref()
.unwrap(),
)
.unwrap();
obj.set_element_flags(gst::ElementFlags::PROVIDE_CLOCK | gst::ElementFlags::REQUIRE_CLOCK); obj.set_element_flags(gst::ElementFlags::PROVIDE_CLOCK | gst::ElementFlags::REQUIRE_CLOCK);
} }
@ -958,6 +993,21 @@ impl ElementImpl for Transcriber {
super::TranslateSrcPad::static_type(), super::TranslateSrcPad::static_type(),
) )
.unwrap(); .unwrap();
let src_caps = gst::Caps::builder("application/x-json").build();
let unsynced_src_pad_template = gst::PadTemplate::new(
"unsynced_src",
gst::PadDirection::Src,
gst::PadPresence::Always,
&src_caps,
)
.unwrap();
let unsynced_sometimes_src_pad_template = gst::PadTemplate::new(
"unsynced_translate_src_%u",
gst::PadDirection::Src,
gst::PadPresence::Request,
&src_caps,
)
.unwrap();
let sink_caps = gst_audio::AudioCapsBuilder::new() let sink_caps = gst_audio::AudioCapsBuilder::new()
.format(gst_audio::AudioFormat::S16le) .format(gst_audio::AudioFormat::S16le)
@ -972,7 +1022,13 @@ impl ElementImpl for Transcriber {
) )
.unwrap(); .unwrap();
vec![src_pad_template, req_src_pad_template, sink_pad_template] vec![
src_pad_template,
req_src_pad_template,
unsynced_src_pad_template,
unsynced_sometimes_src_pad_template,
sink_pad_template,
]
}); });
PAD_TEMPLATES.as_ref() PAD_TEMPLATES.as_ref()
@ -1041,12 +1097,25 @@ impl ElementImpl for Transcriber {
.flags(gst::PadFlags::FIXED_CAPS) .flags(gst::PadFlags::FIXED_CAPS)
.build(); .build();
let templ = self
.obj()
.class()
.pad_template("unsynced_translate_src_%u")
.unwrap();
let static_unsynced_srcpad = gst::PadBuilder::<gst::Pad>::from_template(&templ)
.name(format!("unsynced_translate_src_{}", state.pad_serial).as_str())
.flags(gst::PadFlags::FIXED_CAPS)
.build();
pad.imp().set_unsynced_pad(&static_unsynced_srcpad);
state.srcpads.insert(pad.clone()); state.srcpads.insert(pad.clone());
state.pad_serial += 1; state.pad_serial += 1;
drop(state); drop(state);
self.obj().add_pad(&pad).unwrap(); self.obj().add_pad(&pad).unwrap();
self.obj().add_pad(&static_unsynced_srcpad).unwrap();
let _ = self let _ = self
.obj() .obj()
@ -1059,6 +1128,7 @@ impl ElementImpl for Transcriber {
fn release_pad(&self, pad: &gst::Pad) { fn release_pad(&self, pad: &gst::Pad) {
pad.set_active(false).unwrap(); pad.set_active(false).unwrap();
self.obj().remove_pad(pad).unwrap(); self.obj().remove_pad(pad).unwrap();
self.state.lock().unwrap().srcpads.remove(pad);
self.obj().child_removed(pad, &pad.name()); self.obj().child_removed(pad, &pad.name());
let _ = self let _ = self
@ -1098,6 +1168,7 @@ impl ChildProxyImpl for Transcriber {
.map(|p| p.upcast()) .map(|p| p.upcast())
} }
} }
struct TranslationPadTask { struct TranslationPadTask {
pad: glib::subclass::ObjectImplRef<TranslateSrcPad>, pad: glib::subclass::ObjectImplRef<TranslateSrcPad>,
elem: super::Transcriber, elem: super::Transcriber,
@ -1105,13 +1176,15 @@ struct TranslationPadTask {
needs_translate: bool, needs_translate: bool,
translate_loop_handle: Option<task::JoinHandle<Result<(), gst::ErrorMessage>>>, translate_loop_handle: Option<task::JoinHandle<Result<(), gst::ErrorMessage>>>,
to_translate_tx: Option<mpsc::Sender<Arc<Vec<TranscriptItem>>>>, to_translate_tx: Option<mpsc::Sender<Arc<Vec<TranscriptItem>>>>,
from_translate_rx: Option<mpsc::Receiver<Vec<TranslatedItem>>>, from_translate_rx: Option<mpsc::Receiver<Translation>>,
send_events: bool, send_events: bool,
output_items: VecDeque<OutputItem>, output_items: VecDeque<OutputItem>,
our_latency: gst::ClockTime, our_latency: gst::ClockTime,
seqnum: gst::Seqnum, seqnum: gst::Seqnum,
send_eos: bool, send_eos: bool,
pending_translations: usize, pending_translations: usize,
unsynced_pad: Option<gst::Pad>,
unsynced: Option<String>,
} }
impl TranslationPadTask { impl TranslationPadTask {
@ -1189,6 +1262,8 @@ impl TranslationPadTask {
seqnum: gst::Seqnum::next(), seqnum: gst::Seqnum::next(),
send_eos: false, send_eos: false,
pending_translations: 0, pending_translations: 0,
unsynced_pad: pad.state.lock().unwrap().unsynced_pad.clone(),
unsynced: None,
}) })
} }
} }
@ -1234,8 +1309,12 @@ impl TranslationPadTask {
use TranscriptEvent::*; use TranscriptEvent::*;
use broadcast::error::RecvError; use broadcast::error::RecvError;
match items_res { match items_res {
Ok(Items(transcript_items)) => { Ok(Transcript {
self.output_items.extend(transcript_items.iter().map(Into::into)); items,
serialized,
}) => {
self.unsynced = serialized;
self.output_items.extend(items.iter().map(Into::into));
} }
Ok(Eos) => { Ok(Eos) => {
gst::debug!(CAT, imp = self.pad, "Got eos"); gst::debug!(CAT, imp = self.pad, "Got eos");
@ -1282,7 +1361,10 @@ impl TranslationPadTask {
use TranscriptEvent::*; use TranscriptEvent::*;
use broadcast::error::RecvError; use broadcast::error::RecvError;
match items_res { match items_res {
Ok(Items(items_to_translate)) => Some(items_to_translate), Ok(Transcript {
items,
..
}) => Some(items),
Ok(Eos) => { Ok(Eos) => {
gst::debug!(CAT, imp = self.pad, "Got eos"); gst::debug!(CAT, imp = self.pad, "Got eos");
self.send_eos = true; self.send_eos = true;
@ -1328,15 +1410,24 @@ impl TranslationPadTask {
.as_mut() .as_mut()
.expect("from_translation chan must be available in translation mode"); .expect("from_translation chan must be available in translation mode");
while let Ok(translated_items) = from_translate_rx.try_next() { while let Ok(translation) = from_translate_rx.try_next() {
let Some(translated_items) = translated_items else { let Some(translation) = translation else {
const ERR: &str = "translation chan terminated"; const ERR: &str = "translation chan terminated";
gst::debug!(CAT, imp = self.pad, "{ERR}"); gst::debug!(CAT, imp = self.pad, "{ERR}");
return Err(gst::error_msg!(gst::StreamError::Failed, ["{ERR}"])); return Err(gst::error_msg!(gst::StreamError::Failed, ["{ERR}"]));
}; };
if let Some(pts) = translation.items.first().map(|i| i.pts) {
self.unsynced = Some(
serde_json::json!({
"translation": translation.translation,
"start_time": *pts,
})
.to_string(),
);
}
self.output_items self.output_items
.extend(translated_items.into_iter().map(Into::into)); .extend(translation.items.into_iter().map(Into::into));
self.pending_translations = self.pending_translations.saturating_sub(1); self.pending_translations = self.pending_translations.saturating_sub(1);
} }
@ -1362,6 +1453,35 @@ impl TranslationPadTask {
(last_position, state.discont_pending) (last_position, state.discont_pending)
}; };
if let Some(unsynced) = self.unsynced.take() {
if let Some(ref unsynced_pad) = self.unsynced_pad {
if unsynced_pad.last_flow_result().is_ok() {
gst::log!(
CAT,
obj = unsynced_pad,
"pushing serialized transcript with timestamp {now}"
);
gst::trace!(CAT, obj = unsynced_pad, "serialized transcript: {unsynced}");
let mut buf = gst::Buffer::from_mut_slice(unsynced.into_bytes());
{
let buf_mut = buf.get_mut().unwrap();
buf_mut.set_pts(now);
}
let _ = unsynced_pad.push(buf);
} else {
gst::log!(
CAT,
obj = unsynced_pad,
"not pushing serialized transcript, last flow result: {:?}",
unsynced_pad.last_flow_result()
);
}
}
}
/* First, check our pending buffers */ /* First, check our pending buffers */
while let Some(item) = self.output_items.front() { while let Some(item) = self.output_items.front() {
// Note: items pts start from 0 + lateness // Note: items pts start from 0 + lateness
@ -1531,6 +1651,7 @@ impl TranslationPadTask {
} }
let mut events = vec![]; let mut events = vec![];
let mut unsynced_events = vec![];
{ {
let elem_imp = self.elem.imp(); let elem_imp = self.elem.imp();
@ -1548,16 +1669,31 @@ impl TranslationPadTask {
.build(), .build(),
); );
unsynced_events.push(
gst::event::StreamStart::builder("unsynced-transcription")
.seqnum(self.seqnum)
.build(),
);
let caps = gst::Caps::builder("text/x-raw") let caps = gst::Caps::builder("text/x-raw")
.field("format", "utf8") .field("format", "utf8")
.build(); .build();
events.push(gst::event::Caps::builder(&caps).seqnum(self.seqnum).build()); events.push(gst::event::Caps::builder(&caps).seqnum(self.seqnum).build());
let caps = gst::Caps::builder("application/x-json").build();
unsynced_events.push(gst::event::Caps::builder(&caps).seqnum(self.seqnum).build());
events.push( events.push(
gst::event::Segment::builder(&pad_state.out_segment) gst::event::Segment::builder(&pad_state.out_segment)
.seqnum(self.seqnum) .seqnum(self.seqnum)
.build(), .build(),
); );
unsynced_events.push(
gst::event::Segment::builder(&pad_state.out_segment)
.seqnum(self.seqnum)
.build(),
);
} }
for event in events.drain(..) { for event in events.drain(..) {
@ -1569,6 +1705,13 @@ impl TranslationPadTask {
} }
} }
if let Some(ref unsynced_pad) = self.unsynced_pad {
for event in unsynced_events.drain(..) {
gst::info!(CAT, obj = unsynced_pad, "Sending {event:?}");
let _ = unsynced_pad.push_event(event);
}
}
self.send_events = false; self.send_events = false;
Ok(()) Ok(())
@ -1581,6 +1724,7 @@ struct TranslationPadState {
out_segment: gst::FormattedSegment<gst::ClockTime>, out_segment: gst::FormattedSegment<gst::ClockTime>,
task_abort_handle: Option<AbortHandle>, task_abort_handle: Option<AbortHandle>,
start_time: Option<gst::ClockTime>, start_time: Option<gst::ClockTime>,
unsynced_pad: Option<gst::Pad>,
} }
impl Default for TranslationPadState { impl Default for TranslationPadState {
@ -1590,6 +1734,7 @@ impl Default for TranslationPadState {
out_segment: Default::default(), out_segment: Default::default(),
task_abort_handle: None, task_abort_handle: None,
start_time: None, start_time: None,
unsynced_pad: None,
} }
} }
} }
@ -1607,6 +1752,10 @@ pub struct TranslateSrcPad {
} }
impl TranslateSrcPad { impl TranslateSrcPad {
fn set_unsynced_pad(&self, pad: &gst::Pad) {
self.state.lock().unwrap().unsynced_pad = Some(pad.clone());
}
fn start_task(&self) -> Result<(), gst::LoggableError> { fn start_task(&self) -> Result<(), gst::LoggableError> {
gst::debug!(CAT, imp = self, "Starting task"); gst::debug!(CAT, imp = self, "Starting task");

View file

@ -10,6 +10,7 @@ use gst::glib;
use gst::prelude::*; use gst::prelude::*;
mod imp; mod imp;
mod remote_types;
mod transcribe; mod transcribe;
mod translate; mod translate;

View file

@ -0,0 +1,305 @@
use aws_sdk_transcribestreaming::types as sdk_types;
use serde::{Serialize, Serializer};
use serde_with::{serde_as, SerializeAs};
#[serde_as]
#[derive(serde_derive::Serialize)]
struct EntityDef {
start_time: f64,
end_time: f64,
category: Option<String>,
r#type: Option<String>,
content: Option<String>,
confidence: Option<f64>,
}
impl SerializeAs<sdk_types::Entity> for EntityDef {
fn serialize_as<S>(value: &sdk_types::Entity, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
EntityDef {
start_time: value.start_time,
end_time: value.end_time,
category: value.category.clone(),
r#type: value.r#type.clone(),
content: value.content.clone(),
confidence: value.confidence,
}
.serialize(serializer)
}
}
#[serde_as]
#[derive(serde_derive::Serialize)]
enum ItemTypeDef {
Pronunciation,
Punctuation,
Unknown,
}
impl SerializeAs<sdk_types::ItemType> for ItemTypeDef {
fn serialize_as<S>(value: &sdk_types::ItemType, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
use sdk_types::ItemType::*;
match value {
Pronunciation => ItemTypeDef::Pronunciation,
Punctuation => ItemTypeDef::Punctuation,
_ => ItemTypeDef::Unknown,
}
.serialize(serializer)
}
}
#[serde_as]
#[derive(serde_derive::Serialize)]
struct ItemDef {
start_time: f64,
end_time: f64,
#[serde_as(as = "Option<ItemTypeDef>")]
r#type: Option<sdk_types::ItemType>,
content: Option<String>,
vocabulary_filter_match: bool,
speaker: Option<String>,
confidence: Option<f64>,
stable: Option<bool>,
}
impl SerializeAs<sdk_types::Item> for ItemDef {
fn serialize_as<S>(value: &sdk_types::Item, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
ItemDef {
start_time: value.start_time,
end_time: value.end_time,
r#type: value.r#type.clone(),
content: value.content.clone(),
vocabulary_filter_match: value.vocabulary_filter_match,
speaker: value.speaker.clone(),
confidence: value.confidence,
stable: value.stable,
}
.serialize(serializer)
}
}
#[serde_as]
#[derive(serde_derive::Serialize)]
struct AlternativeDef {
transcript: Option<String>,
#[serde_as(as = "Option<Vec<ItemDef>>")]
items: Option<Vec<sdk_types::Item>>,
#[serde_as(as = "Option<Vec<EntityDef>>")]
entities: Option<Vec<sdk_types::Entity>>,
}
impl SerializeAs<sdk_types::Alternative> for AlternativeDef {
fn serialize_as<S>(value: &sdk_types::Alternative, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
AlternativeDef {
transcript: value.transcript.clone(),
items: value.items.clone(),
entities: value.entities.clone(),
}
.serialize(serializer)
}
}
#[serde_as]
#[derive(serde_derive::Serialize)]
struct LanguageWithScoreDef {
#[serde_as(as = "Option<LanguageCodeDef>")]
language_code: Option<sdk_types::LanguageCode>,
score: f64,
}
impl SerializeAs<sdk_types::LanguageWithScore> for LanguageWithScoreDef {
fn serialize_as<S>(
value: &sdk_types::LanguageWithScore,
serializer: S,
) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
LanguageWithScoreDef {
language_code: value.language_code.clone(),
score: value.score,
}
.serialize(serializer)
}
}
#[serde_as]
#[derive(serde_derive::Serialize)]
enum LanguageCodeDef {
AfZa,
ArAe,
ArSa,
CaEs,
CsCz,
DaDk,
DeCh,
DeDe,
ElGr,
EnAb,
EnAu,
EnGb,
EnIe,
EnIn,
EnNz,
EnUs,
EnWl,
EnZa,
EsEs,
EsUs,
EuEs,
FaIr,
FiFi,
FrCa,
FrFr,
GlEs,
HeIl,
HiIn,
HrHr,
IdId,
ItIt,
JaJp,
KoKr,
LvLv,
MsMy,
NlNl,
NoNo,
PlPl,
PtBr,
PtPt,
RoRo,
RuRu,
SkSk,
SoSo,
SrRs,
SvSe,
ThTh,
TlPh,
UkUa,
ViVn,
ZhCn,
ZhHk,
ZhTw,
ZuZa,
Unknown,
}
impl SerializeAs<sdk_types::LanguageCode> for LanguageCodeDef {
fn serialize_as<S>(value: &sdk_types::LanguageCode, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
use sdk_types::LanguageCode::*;
match value {
AfZa => LanguageCodeDef::AfZa,
ArAe => LanguageCodeDef::ArAe,
ArSa => LanguageCodeDef::ArSa,
CaEs => LanguageCodeDef::CaEs,
CsCz => LanguageCodeDef::CsCz,
DaDk => LanguageCodeDef::DaDk,
DeCh => LanguageCodeDef::DeCh,
DeDe => LanguageCodeDef::DeDe,
ElGr => LanguageCodeDef::ElGr,
EnAb => LanguageCodeDef::EnAb,
EnAu => LanguageCodeDef::EnAu,
EnGb => LanguageCodeDef::EnGb,
EnIe => LanguageCodeDef::EnIe,
EnIn => LanguageCodeDef::EnIn,
EnNz => LanguageCodeDef::EnNz,
EnUs => LanguageCodeDef::EnUs,
EnWl => LanguageCodeDef::EnWl,
EnZa => LanguageCodeDef::EnZa,
EsEs => LanguageCodeDef::EsEs,
EsUs => LanguageCodeDef::EsUs,
EuEs => LanguageCodeDef::EuEs,
FaIr => LanguageCodeDef::FaIr,
FiFi => LanguageCodeDef::FiFi,
FrCa => LanguageCodeDef::FrCa,
FrFr => LanguageCodeDef::FrFr,
GlEs => LanguageCodeDef::GlEs,
HeIl => LanguageCodeDef::HeIl,
HiIn => LanguageCodeDef::HiIn,
HrHr => LanguageCodeDef::HrHr,
IdId => LanguageCodeDef::IdId,
ItIt => LanguageCodeDef::ItIt,
JaJp => LanguageCodeDef::JaJp,
KoKr => LanguageCodeDef::KoKr,
LvLv => LanguageCodeDef::LvLv,
MsMy => LanguageCodeDef::MsMy,
NlNl => LanguageCodeDef::NlNl,
NoNo => LanguageCodeDef::NoNo,
PlPl => LanguageCodeDef::PlPl,
PtBr => LanguageCodeDef::PtBr,
PtPt => LanguageCodeDef::PtPt,
RoRo => LanguageCodeDef::RoRo,
RuRu => LanguageCodeDef::RuRu,
SkSk => LanguageCodeDef::SkSk,
SoSo => LanguageCodeDef::SoSo,
SrRs => LanguageCodeDef::SrRs,
SvSe => LanguageCodeDef::SvSe,
ThTh => LanguageCodeDef::ThTh,
TlPh => LanguageCodeDef::TlPh,
UkUa => LanguageCodeDef::UkUa,
ViVn => LanguageCodeDef::ViVn,
ZhCn => LanguageCodeDef::ZhCn,
ZhHk => LanguageCodeDef::ZhHk,
ZhTw => LanguageCodeDef::ZhTw,
ZuZa => LanguageCodeDef::ZuZa,
_ => LanguageCodeDef::Unknown,
}
.serialize(serializer)
}
}
#[serde_as]
#[derive(serde_derive::Serialize)]
struct ResultDef {
result_id: Option<String>,
start_time: f64,
end_time: f64,
is_partial: bool,
#[serde_as(as = "Option<Vec<AlternativeDef>>")]
alternatives: Option<Vec<sdk_types::Alternative>>,
channel_id: Option<String>,
#[serde_as(as = "Option<LanguageCodeDef>")]
language_code: Option<sdk_types::LanguageCode>,
#[serde_as(as = "Option<Vec<LanguageWithScoreDef>>")]
language_identification: Option<Vec<sdk_types::LanguageWithScore>>,
}
impl SerializeAs<sdk_types::Result> for ResultDef {
fn serialize_as<S>(value: &sdk_types::Result, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let def = ResultDef {
result_id: value.result_id.clone(),
start_time: value.start_time,
end_time: value.end_time,
is_partial: value.is_partial,
alternatives: value.alternatives.clone(),
channel_id: value.channel_id.clone(),
language_code: value.language_code.clone(),
language_identification: value.language_identification.clone(),
};
def.serialize(serializer)
}
}
#[serde_as]
#[derive(serde_derive::Serialize)]
pub struct TranscriptDef {
#[serde_as(as = "Option<Vec<ResultDef>>")]
pub results: Option<Vec<sdk_types::Result>>,
}

View file

@ -20,6 +20,7 @@ use futures::prelude::*;
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
use super::imp::{Settings, Transcriber}; use super::imp::{Settings, Transcriber};
use super::remote_types::TranscriptDef;
use super::CAT; use super::CAT;
#[derive(Debug)] #[derive(Debug)]
@ -81,16 +82,13 @@ impl TranscriptItem {
#[derive(Clone)] #[derive(Clone)]
pub enum TranscriptEvent { pub enum TranscriptEvent {
Items(Arc<Vec<TranscriptItem>>), Transcript {
items: Arc<Vec<TranscriptItem>>,
serialized: Option<String>,
},
Eos, Eos,
} }
impl From<Vec<TranscriptItem>> for TranscriptEvent {
fn from(transcript_items: Vec<TranscriptItem>) -> Self {
TranscriptEvent::Items(transcript_items.into())
}
}
struct DiscontOffsetTracker { struct DiscontOffsetTracker {
discont_offset: gst::ClockTime, discont_offset: gst::ClockTime,
last_chained_buffer_rtime: Option<gst::ClockTime>, last_chained_buffer_rtime: Option<gst::ClockTime>,
@ -198,29 +196,36 @@ impl TranscriberStream {
return Ok(TranscriptEvent::Eos); return Ok(TranscriptEvent::Eos);
}; };
if let types::TranscriptResultStream::TranscriptEvent(transcript_evt) = event { if let types::TranscriptResultStream::TranscriptEvent(mut transcript_evt) = event {
let mut ready_items = None; let Some(ref mut transcript) = transcript_evt.transcript else {
continue;
};
if let Some(result) = transcript_evt let t = TranscriptDef {
.transcript results: transcript.results.clone(),
.and_then(|transcript| transcript.results) };
.and_then(|mut results| results.drain(..).next())
{
gst::trace!(CAT, imp = self.imp, "Received: {result:?}");
if let Some(alternative) = result let serialized = serde_json::to_string(&t).expect("serializable");
.alternatives
.and_then(|mut alternatives| alternatives.drain(..).next())
{
ready_items = alternative.items.and_then(|items| {
self.get_ready_transcript_items(items, result.is_partial)
});
}
}
if let Some(ready_items) = ready_items { let Some(result) = transcript
return Ok(ready_items.into()); .results
} .as_mut()
.and_then(|results| results.drain(..).next())
else {
continue;
};
let ready_items = result
.alternatives
.and_then(|mut alternatives| alternatives.drain(..).next())
.and_then(|alternative| alternative.items)
.map(|items| self.get_ready_transcript_items(items, result.is_partial))
.unwrap_or(vec![]);
return Ok(TranscriptEvent::Transcript {
items: ready_items.into(),
serialized: Some(serialized),
});
} else { } else {
gst::warning!( gst::warning!(
CAT, CAT,
@ -236,7 +241,9 @@ impl TranscriberStream {
&mut self, &mut self,
mut items: Vec<types::Item>, mut items: Vec<types::Item>,
partial: bool, partial: bool,
) -> Option<Vec<TranscriptItem>> { ) -> Vec<TranscriptItem> {
let mut output = vec![];
if items.len() < self.partial_index { if items.len() < self.partial_index {
gst::error!( gst::error!(
CAT, CAT,
@ -250,11 +257,9 @@ impl TranscriberStream {
self.partial_index = 0; self.partial_index = 0;
} }
return None; return output;
} }
let mut output = vec![];
for item in items.drain(self.partial_index..) { for item in items.drain(self.partial_index..) {
if !item.stable().unwrap_or(false) { if !item.stable().unwrap_or(false) {
break; break;
@ -281,10 +286,6 @@ impl TranscriberStream {
self.partial_index = 0; self.partial_index = 0;
} }
if output.is_empty() { output
return None;
}
Some(output)
} }
} }

View file

@ -41,6 +41,12 @@ impl From<&TranscriptItem> for TranslatedItem {
} }
} }
#[derive(Debug)]
pub struct Translation {
pub items: Vec<TranslatedItem>,
pub translation: String,
}
pub struct TranslateLoop { pub struct TranslateLoop {
pad: glib::subclass::ObjectImplRef<TranslateSrcPad>, pad: glib::subclass::ObjectImplRef<TranslateSrcPad>,
client: aws_translate::Client, client: aws_translate::Client,
@ -48,7 +54,7 @@ pub struct TranslateLoop {
output_lang: String, output_lang: String,
tokenization_method: TranslationTokenizationMethod, tokenization_method: TranslationTokenizationMethod,
transcript_rx: mpsc::Receiver<Arc<Vec<TranscriptItem>>>, transcript_rx: mpsc::Receiver<Arc<Vec<TranscriptItem>>>,
translate_tx: mpsc::Sender<Vec<TranslatedItem>>, translate_tx: mpsc::Sender<Translation>,
} }
impl TranslateLoop { impl TranslateLoop {
@ -59,7 +65,7 @@ impl TranslateLoop {
output_lang: &str, output_lang: &str,
tokenization_method: TranslationTokenizationMethod, tokenization_method: TranslationTokenizationMethod,
transcript_rx: mpsc::Receiver<Arc<Vec<TranscriptItem>>>, transcript_rx: mpsc::Receiver<Arc<Vec<TranscriptItem>>>,
translate_tx: mpsc::Sender<Vec<TranslatedItem>>, translate_tx: mpsc::Sender<Translation>,
) -> Self { ) -> Self {
let aws_config = imp.aws_config.lock().unwrap(); let aws_config = imp.aws_config.lock().unwrap();
let aws_config = aws_config let aws_config = aws_config
@ -171,7 +177,7 @@ impl TranslateLoop {
vec![TranslatedItem { vec![TranslatedItem {
pts: first_pts, pts: first_pts,
duration: last_pts.saturating_sub(first_pts) + last_duration, duration: last_pts.saturating_sub(first_pts) + last_duration,
content: translated_text, content: translated_text.clone(),
}] }]
} }
Tokenization::SpanBased => span_tokenize_items(&translated_text, ts_duration_list), Tokenization::SpanBased => span_tokenize_items(&translated_text, ts_duration_list),
@ -179,7 +185,15 @@ impl TranslateLoop {
gst::trace!(CAT, imp = self.pad, "Sending {translated_items:?}"); gst::trace!(CAT, imp = self.pad, "Sending {translated_items:?}");
if self.translate_tx.send(translated_items).await.is_err() { if self
.translate_tx
.send(Translation {
items: translated_items,
translation: translated_text,
})
.await
.is_err()
{
gst::info!( gst::info!(
CAT, CAT,
imp = self.pad, imp = self.pad,