From b5bd7d047cad7d9974f2daac57a8ee8e3002ad72 Mon Sep 17 00:00:00 2001 From: Mathieu Duponchelle Date: Wed, 13 Nov 2024 15:55:35 +0100 Subject: [PATCH] awstranscribe: output original transcripts to separate pad When the transcriber is used in a live situation, it can be useful to save a transcript for editing after the fact when producing a VOD. Each source pad now gets an "unsynced_" counterpart. Buffers are pushed to that unsynced pad from the context of the "live" source pad task. Flow returns from the unsynced pads are ignored; we simply check the last flow return before attempting to push the next transcript. Part-of: --- Cargo.lock | 1 + docs/plugins/gst_plugins_cache.json | 10 + net/aws/Cargo.toml | 1 + net/aws/src/transcriber/imp.rs | 177 ++++++++++++-- net/aws/src/transcriber/mod.rs | 1 + net/aws/src/transcriber/remote_types.rs | 305 ++++++++++++++++++++++++ net/aws/src/transcriber/transcribe.rs | 73 +++--- net/aws/src/transcriber/translate.rs | 22 +- 8 files changed, 536 insertions(+), 54 deletions(-) create mode 100644 net/aws/src/transcriber/remote_types.rs diff --git a/Cargo.lock b/Cargo.lock index 86a3a5ea..f0cbbb8c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2458,6 +2458,7 @@ dependencies = [ "serde", "serde_derive", "serde_json", + "serde_with", "sprintf", "test-with", "tokio", diff --git a/docs/plugins/gst_plugins_cache.json b/docs/plugins/gst_plugins_cache.json index d014c3b2..5a635a87 100644 --- a/docs/plugins/gst_plugins_cache.json +++ b/docs/plugins/gst_plugins_cache.json @@ -1162,6 +1162,16 @@ "direction": "src", "presence": "request", "type": "GstTranslateSrcPad" + }, + "unsynced_src": { + "caps": "application/x-json:\n", + "direction": "src", + "presence": "always" + }, + "unsynced_translate_src_%%u": { + "caps": "application/x-json:\n", + "direction": "src", + "presence": "request" } }, "properties": { diff --git a/net/aws/Cargo.toml b/net/aws/Cargo.toml index f9f90c2f..b3bfd6e5 100644 --- a/net/aws/Cargo.toml +++ b/net/aws/Cargo.toml @@ -31,6 +31,7 @@ 
tokio = { version = "1.0", features = [ "full" ] } serde = "1" serde_derive = "1" serde_json = "1" +serde_with = "3" url = "2" gst-video = { workspace = true, features = ["v1_22"] } sprintf = "0.2" diff --git a/net/aws/src/transcriber/imp.rs b/net/aws/src/transcriber/imp.rs index 2ba32452..4dba1587 100644 --- a/net/aws/src/transcriber/imp.rs +++ b/net/aws/src/transcriber/imp.rs @@ -35,7 +35,7 @@ use std::sync::{Arc, Mutex}; use std::sync::LazyLock; use super::transcribe::{TranscriberSettings, TranscriberStream, TranscriptEvent, TranscriptItem}; -use super::translate::{TranslateLoop, TranslatedItem}; +use super::translate::{TranslateLoop, TranslatedItem, Translation}; use super::{ AwsTranscriberResultStability, AwsTranscriberVocabularyFilterMethod, TranslationTokenizationMethod, CAT, @@ -462,9 +462,16 @@ impl Transcriber { use TranscriptEvent::*; match res { None => (), - Some(Items(items)) => { + Some(Transcript { + items, + serialized + }) => { if imp.transcript_event_tx.receiver_count() > 0 { - let _ = imp.transcript_event_tx.send(Items(items.clone())); + let _ = imp.transcript_event_tx.send( + Transcript { + items: items.clone(), + serialized: serialized.clone() + }); } if imp.transcript_event_for_translate_tx.receiver_count() > 0 { @@ -472,7 +479,10 @@ impl Transcriber { if let Some(items_to_translate) = translate_queue.push(item) { let _ = imp .transcript_event_for_translate_tx - .send(Items(items_to_translate.into())); + .send(Transcript { + items: items_to_translate.into(), + serialized: None, + }); } } } @@ -489,7 +499,10 @@ impl Transcriber { translate_queue.drain().collect(); let _ = imp .transcript_event_for_translate_tx - .send(Items(items_to_translate.into())); + .send(Transcript { + items: items_to_translate.into(), + serialized: None, + }); let _ = imp.transcript_event_for_translate_tx.send(Eos); } @@ -518,7 +531,10 @@ impl Transcriber { ); let _ = imp .transcript_event_for_translate_tx - .send(Items(items_to_translate.into())); + .send(Transcript { 
+ items: items_to_translate.into(), + serialized: None, + }); } } } @@ -662,12 +678,20 @@ impl ObjectSubclass for Transcriber { .flags(gst::PadFlags::FIXED_CAPS) .build(); + let templ = klass.pad_template("unsynced_src").unwrap(); + let static_unsynced_srcpad = gst::PadBuilder::::from_template(&templ) + .flags(gst::PadFlags::FIXED_CAPS) + .build(); // Setting the channel capacity so that a TranslateSrcPad that would lag // behind for some reasons get a chance to catch-up without loosing items. // Receiver will be created by subscribing to sender later. let (transcript_event_for_translate_tx, _) = broadcast::channel(128); let (transcript_event_tx, _) = broadcast::channel(128); + static_srcpad + .imp() + .set_unsynced_pad(&static_unsynced_srcpad); + Self { static_srcpad, sinkpad, @@ -785,6 +809,17 @@ impl ObjectImpl for Transcriber { let obj = self.obj(); obj.add_pad(&self.sinkpad).unwrap(); obj.add_pad(&self.static_srcpad).unwrap(); + obj.add_pad( + self.static_srcpad + .imp() + .state + .lock() + .unwrap() + .unsynced_pad + .as_ref() + .unwrap(), + ) + .unwrap(); obj.set_element_flags(gst::ElementFlags::PROVIDE_CLOCK | gst::ElementFlags::REQUIRE_CLOCK); } @@ -958,6 +993,21 @@ impl ElementImpl for Transcriber { super::TranslateSrcPad::static_type(), ) .unwrap(); + let src_caps = gst::Caps::builder("application/x-json").build(); + let unsynced_src_pad_template = gst::PadTemplate::new( + "unsynced_src", + gst::PadDirection::Src, + gst::PadPresence::Always, + &src_caps, + ) + .unwrap(); + let unsynced_sometimes_src_pad_template = gst::PadTemplate::new( + "unsynced_translate_src_%u", + gst::PadDirection::Src, + gst::PadPresence::Request, + &src_caps, + ) + .unwrap(); let sink_caps = gst_audio::AudioCapsBuilder::new() .format(gst_audio::AudioFormat::S16le) @@ -972,7 +1022,13 @@ impl ElementImpl for Transcriber { ) .unwrap(); - vec![src_pad_template, req_src_pad_template, sink_pad_template] + vec![ + src_pad_template, + req_src_pad_template, + unsynced_src_pad_template, + 
unsynced_sometimes_src_pad_template, + sink_pad_template, + ] }); PAD_TEMPLATES.as_ref() @@ -1041,12 +1097,25 @@ impl ElementImpl for Transcriber { .flags(gst::PadFlags::FIXED_CAPS) .build(); + let templ = self + .obj() + .class() + .pad_template("unsynced_translate_src_%u") + .unwrap(); + let static_unsynced_srcpad = gst::PadBuilder::::from_template(&templ) + .name(format!("unsynced_translate_src_{}", state.pad_serial).as_str()) + .flags(gst::PadFlags::FIXED_CAPS) + .build(); + + pad.imp().set_unsynced_pad(&static_unsynced_srcpad); + state.srcpads.insert(pad.clone()); state.pad_serial += 1; drop(state); self.obj().add_pad(&pad).unwrap(); + self.obj().add_pad(&static_unsynced_srcpad).unwrap(); let _ = self .obj() @@ -1059,6 +1128,7 @@ impl ElementImpl for Transcriber { fn release_pad(&self, pad: &gst::Pad) { pad.set_active(false).unwrap(); self.obj().remove_pad(pad).unwrap(); + self.state.lock().unwrap().srcpads.remove(pad); self.obj().child_removed(pad, &pad.name()); let _ = self @@ -1098,6 +1168,7 @@ impl ChildProxyImpl for Transcriber { .map(|p| p.upcast()) } } + struct TranslationPadTask { pad: glib::subclass::ObjectImplRef, elem: super::Transcriber, @@ -1105,13 +1176,15 @@ struct TranslationPadTask { needs_translate: bool, translate_loop_handle: Option>>, to_translate_tx: Option>>>, - from_translate_rx: Option>>, + from_translate_rx: Option>, send_events: bool, output_items: VecDeque, our_latency: gst::ClockTime, seqnum: gst::Seqnum, send_eos: bool, pending_translations: usize, + unsynced_pad: Option, + unsynced: Option, } impl TranslationPadTask { @@ -1189,6 +1262,8 @@ impl TranslationPadTask { seqnum: gst::Seqnum::next(), send_eos: false, pending_translations: 0, + unsynced_pad: pad.state.lock().unwrap().unsynced_pad.clone(), + unsynced: None, }) } } @@ -1234,8 +1309,12 @@ impl TranslationPadTask { use TranscriptEvent::*; use broadcast::error::RecvError; match items_res { - Ok(Items(transcript_items)) => { - 
self.output_items.extend(transcript_items.iter().map(Into::into)); + Ok(Transcript { + items, + serialized, + }) => { + self.unsynced = serialized; + self.output_items.extend(items.iter().map(Into::into)); } Ok(Eos) => { gst::debug!(CAT, imp = self.pad, "Got eos"); @@ -1282,7 +1361,10 @@ impl TranslationPadTask { use TranscriptEvent::*; use broadcast::error::RecvError; match items_res { - Ok(Items(items_to_translate)) => Some(items_to_translate), + Ok(Transcript { + items, + .. + }) => Some(items), Ok(Eos) => { gst::debug!(CAT, imp = self.pad, "Got eos"); self.send_eos = true; @@ -1328,15 +1410,24 @@ impl TranslationPadTask { .as_mut() .expect("from_translation chan must be available in translation mode"); - while let Ok(translated_items) = from_translate_rx.try_next() { - let Some(translated_items) = translated_items else { + while let Ok(translation) = from_translate_rx.try_next() { + let Some(translation) = translation else { const ERR: &str = "translation chan terminated"; gst::debug!(CAT, imp = self.pad, "{ERR}"); return Err(gst::error_msg!(gst::StreamError::Failed, ["{ERR}"])); }; + if let Some(pts) = translation.items.first().map(|i| i.pts) { + self.unsynced = Some( + serde_json::json!({ + "translation": translation.translation, + "start_time": *pts, + }) + .to_string(), + ); + } self.output_items - .extend(translated_items.into_iter().map(Into::into)); + .extend(translation.items.into_iter().map(Into::into)); self.pending_translations = self.pending_translations.saturating_sub(1); } @@ -1362,6 +1453,35 @@ impl TranslationPadTask { (last_position, state.discont_pending) }; + if let Some(unsynced) = self.unsynced.take() { + if let Some(ref unsynced_pad) = self.unsynced_pad { + if unsynced_pad.last_flow_result().is_ok() { + gst::log!( + CAT, + obj = unsynced_pad, + "pushing serialized transcript with timestamp {now}" + ); + gst::trace!(CAT, obj = unsynced_pad, "serialized transcript: {unsynced}"); + + let mut buf = 
gst::Buffer::from_mut_slice(unsynced.into_bytes()); + { + let buf_mut = buf.get_mut().unwrap(); + + buf_mut.set_pts(now); + } + + let _ = unsynced_pad.push(buf); + } else { + gst::log!( + CAT, + obj = unsynced_pad, + "not pushing serialized transcript, last flow result: {:?}", + unsynced_pad.last_flow_result() + ); + } + } + } + /* First, check our pending buffers */ while let Some(item) = self.output_items.front() { // Note: items pts start from 0 + lateness @@ -1531,6 +1651,7 @@ impl TranslationPadTask { } let mut events = vec![]; + let mut unsynced_events = vec![]; { let elem_imp = self.elem.imp(); @@ -1548,16 +1669,31 @@ impl TranslationPadTask { .build(), ); + unsynced_events.push( + gst::event::StreamStart::builder("unsynced-transcription") + .seqnum(self.seqnum) + .build(), + ); + let caps = gst::Caps::builder("text/x-raw") .field("format", "utf8") .build(); events.push(gst::event::Caps::builder(&caps).seqnum(self.seqnum).build()); + let caps = gst::Caps::builder("application/x-json").build(); + unsynced_events.push(gst::event::Caps::builder(&caps).seqnum(self.seqnum).build()); + events.push( gst::event::Segment::builder(&pad_state.out_segment) .seqnum(self.seqnum) .build(), ); + + unsynced_events.push( + gst::event::Segment::builder(&pad_state.out_segment) + .seqnum(self.seqnum) + .build(), + ); } for event in events.drain(..) { @@ -1569,6 +1705,13 @@ impl TranslationPadTask { } } + if let Some(ref unsynced_pad) = self.unsynced_pad { + for event in unsynced_events.drain(..) 
{ + gst::info!(CAT, obj = unsynced_pad, "Sending {event:?}"); + let _ = unsynced_pad.push_event(event); + } + } + self.send_events = false; Ok(()) @@ -1581,6 +1724,7 @@ struct TranslationPadState { out_segment: gst::FormattedSegment, task_abort_handle: Option, start_time: Option, + unsynced_pad: Option, } impl Default for TranslationPadState { @@ -1590,6 +1734,7 @@ impl Default for TranslationPadState { out_segment: Default::default(), task_abort_handle: None, start_time: None, + unsynced_pad: None, } } } @@ -1607,6 +1752,10 @@ pub struct TranslateSrcPad { } impl TranslateSrcPad { + fn set_unsynced_pad(&self, pad: &gst::Pad) { + self.state.lock().unwrap().unsynced_pad = Some(pad.clone()); + } + fn start_task(&self) -> Result<(), gst::LoggableError> { gst::debug!(CAT, imp = self, "Starting task"); diff --git a/net/aws/src/transcriber/mod.rs b/net/aws/src/transcriber/mod.rs index c12386d4..8e9f9c76 100644 --- a/net/aws/src/transcriber/mod.rs +++ b/net/aws/src/transcriber/mod.rs @@ -10,6 +10,7 @@ use gst::glib; use gst::prelude::*; mod imp; +mod remote_types; mod transcribe; mod translate; diff --git a/net/aws/src/transcriber/remote_types.rs b/net/aws/src/transcriber/remote_types.rs new file mode 100644 index 00000000..e71fbcaf --- /dev/null +++ b/net/aws/src/transcriber/remote_types.rs @@ -0,0 +1,305 @@ +use aws_sdk_transcribestreaming::types as sdk_types; + +use serde::{Serialize, Serializer}; +use serde_with::{serde_as, SerializeAs}; + +#[serde_as] +#[derive(serde_derive::Serialize)] +struct EntityDef { + start_time: f64, + end_time: f64, + category: Option, + r#type: Option, + content: Option, + confidence: Option, +} + +impl SerializeAs for EntityDef { + fn serialize_as(value: &sdk_types::Entity, serializer: S) -> Result + where + S: Serializer, + { + EntityDef { + start_time: value.start_time, + end_time: value.end_time, + category: value.category.clone(), + r#type: value.r#type.clone(), + content: value.content.clone(), + confidence: value.confidence, + } + 
.serialize(serializer) + } +} + +#[serde_as] +#[derive(serde_derive::Serialize)] +enum ItemTypeDef { + Pronunciation, + Punctuation, + Unknown, +} + +impl SerializeAs for ItemTypeDef { + fn serialize_as(value: &sdk_types::ItemType, serializer: S) -> Result + where + S: Serializer, + { + use sdk_types::ItemType::*; + match value { + Pronunciation => ItemTypeDef::Pronunciation, + Punctuation => ItemTypeDef::Punctuation, + _ => ItemTypeDef::Unknown, + } + .serialize(serializer) + } +} + +#[serde_as] +#[derive(serde_derive::Serialize)] +struct ItemDef { + start_time: f64, + end_time: f64, + #[serde_as(as = "Option")] + r#type: Option, + content: Option, + vocabulary_filter_match: bool, + speaker: Option, + confidence: Option, + stable: Option, +} + +impl SerializeAs for ItemDef { + fn serialize_as(value: &sdk_types::Item, serializer: S) -> Result + where + S: Serializer, + { + ItemDef { + start_time: value.start_time, + end_time: value.end_time, + r#type: value.r#type.clone(), + content: value.content.clone(), + vocabulary_filter_match: value.vocabulary_filter_match, + speaker: value.speaker.clone(), + confidence: value.confidence, + stable: value.stable, + } + .serialize(serializer) + } +} + +#[serde_as] +#[derive(serde_derive::Serialize)] +struct AlternativeDef { + transcript: Option, + #[serde_as(as = "Option>")] + items: Option>, + #[serde_as(as = "Option>")] + entities: Option>, +} + +impl SerializeAs for AlternativeDef { + fn serialize_as(value: &sdk_types::Alternative, serializer: S) -> Result + where + S: Serializer, + { + AlternativeDef { + transcript: value.transcript.clone(), + items: value.items.clone(), + entities: value.entities.clone(), + } + .serialize(serializer) + } +} + +#[serde_as] +#[derive(serde_derive::Serialize)] +struct LanguageWithScoreDef { + #[serde_as(as = "Option")] + language_code: Option, + score: f64, +} + +impl SerializeAs for LanguageWithScoreDef { + fn serialize_as( + value: &sdk_types::LanguageWithScore, + serializer: S, + ) -> 
Result + where + S: Serializer, + { + LanguageWithScoreDef { + language_code: value.language_code.clone(), + score: value.score, + } + .serialize(serializer) + } +} + +#[serde_as] +#[derive(serde_derive::Serialize)] +enum LanguageCodeDef { + AfZa, + ArAe, + ArSa, + CaEs, + CsCz, + DaDk, + DeCh, + DeDe, + ElGr, + EnAb, + EnAu, + EnGb, + EnIe, + EnIn, + EnNz, + EnUs, + EnWl, + EnZa, + EsEs, + EsUs, + EuEs, + FaIr, + FiFi, + FrCa, + FrFr, + GlEs, + HeIl, + HiIn, + HrHr, + IdId, + ItIt, + JaJp, + KoKr, + LvLv, + MsMy, + NlNl, + NoNo, + PlPl, + PtBr, + PtPt, + RoRo, + RuRu, + SkSk, + SoSo, + SrRs, + SvSe, + ThTh, + TlPh, + UkUa, + ViVn, + ZhCn, + ZhHk, + ZhTw, + ZuZa, + Unknown, +} + +impl SerializeAs for LanguageCodeDef { + fn serialize_as(value: &sdk_types::LanguageCode, serializer: S) -> Result + where + S: Serializer, + { + use sdk_types::LanguageCode::*; + match value { + AfZa => LanguageCodeDef::AfZa, + ArAe => LanguageCodeDef::ArAe, + ArSa => LanguageCodeDef::ArSa, + CaEs => LanguageCodeDef::CaEs, + CsCz => LanguageCodeDef::CsCz, + DaDk => LanguageCodeDef::DaDk, + DeCh => LanguageCodeDef::DeCh, + DeDe => LanguageCodeDef::DeDe, + ElGr => LanguageCodeDef::ElGr, + EnAb => LanguageCodeDef::EnAb, + EnAu => LanguageCodeDef::EnAu, + EnGb => LanguageCodeDef::EnGb, + EnIe => LanguageCodeDef::EnIe, + EnIn => LanguageCodeDef::EnIn, + EnNz => LanguageCodeDef::EnNz, + EnUs => LanguageCodeDef::EnUs, + EnWl => LanguageCodeDef::EnWl, + EnZa => LanguageCodeDef::EnZa, + EsEs => LanguageCodeDef::EsEs, + EsUs => LanguageCodeDef::EsUs, + EuEs => LanguageCodeDef::EuEs, + FaIr => LanguageCodeDef::FaIr, + FiFi => LanguageCodeDef::FiFi, + FrCa => LanguageCodeDef::FrCa, + FrFr => LanguageCodeDef::FrFr, + GlEs => LanguageCodeDef::GlEs, + HeIl => LanguageCodeDef::HeIl, + HiIn => LanguageCodeDef::HiIn, + HrHr => LanguageCodeDef::HrHr, + IdId => LanguageCodeDef::IdId, + ItIt => LanguageCodeDef::ItIt, + JaJp => LanguageCodeDef::JaJp, + KoKr => LanguageCodeDef::KoKr, + LvLv => 
LanguageCodeDef::LvLv, + MsMy => LanguageCodeDef::MsMy, + NlNl => LanguageCodeDef::NlNl, + NoNo => LanguageCodeDef::NoNo, + PlPl => LanguageCodeDef::PlPl, + PtBr => LanguageCodeDef::PtBr, + PtPt => LanguageCodeDef::PtPt, + RoRo => LanguageCodeDef::RoRo, + RuRu => LanguageCodeDef::RuRu, + SkSk => LanguageCodeDef::SkSk, + SoSo => LanguageCodeDef::SoSo, + SrRs => LanguageCodeDef::SrRs, + SvSe => LanguageCodeDef::SvSe, + ThTh => LanguageCodeDef::ThTh, + TlPh => LanguageCodeDef::TlPh, + UkUa => LanguageCodeDef::UkUa, + ViVn => LanguageCodeDef::ViVn, + ZhCn => LanguageCodeDef::ZhCn, + ZhHk => LanguageCodeDef::ZhHk, + ZhTw => LanguageCodeDef::ZhTw, + ZuZa => LanguageCodeDef::ZuZa, + _ => LanguageCodeDef::Unknown, + } + .serialize(serializer) + } +} + +#[serde_as] +#[derive(serde_derive::Serialize)] +struct ResultDef { + result_id: Option, + start_time: f64, + end_time: f64, + is_partial: bool, + #[serde_as(as = "Option>")] + alternatives: Option>, + channel_id: Option, + #[serde_as(as = "Option")] + language_code: Option, + #[serde_as(as = "Option>")] + language_identification: Option>, +} + +impl SerializeAs for ResultDef { + fn serialize_as(value: &sdk_types::Result, serializer: S) -> Result + where + S: Serializer, + { + let def = ResultDef { + result_id: value.result_id.clone(), + start_time: value.start_time, + end_time: value.end_time, + is_partial: value.is_partial, + alternatives: value.alternatives.clone(), + channel_id: value.channel_id.clone(), + language_code: value.language_code.clone(), + language_identification: value.language_identification.clone(), + }; + def.serialize(serializer) + } +} + +#[serde_as] +#[derive(serde_derive::Serialize)] +pub struct TranscriptDef { + #[serde_as(as = "Option>")] + pub results: Option>, +} diff --git a/net/aws/src/transcriber/transcribe.rs b/net/aws/src/transcriber/transcribe.rs index 10ed4763..3da82666 100644 --- a/net/aws/src/transcriber/transcribe.rs +++ b/net/aws/src/transcriber/transcribe.rs @@ -20,6 +20,7 @@ use 
futures::prelude::*; use std::sync::{Arc, Mutex}; use super::imp::{Settings, Transcriber}; +use super::remote_types::TranscriptDef; use super::CAT; #[derive(Debug)] @@ -81,16 +82,13 @@ impl TranscriptItem { #[derive(Clone)] pub enum TranscriptEvent { - Items(Arc>), + Transcript { + items: Arc>, + serialized: Option, + }, Eos, } -impl From> for TranscriptEvent { - fn from(transcript_items: Vec) -> Self { - TranscriptEvent::Items(transcript_items.into()) - } -} - struct DiscontOffsetTracker { discont_offset: gst::ClockTime, last_chained_buffer_rtime: Option, @@ -198,29 +196,36 @@ impl TranscriberStream { return Ok(TranscriptEvent::Eos); }; - if let types::TranscriptResultStream::TranscriptEvent(transcript_evt) = event { - let mut ready_items = None; + if let types::TranscriptResultStream::TranscriptEvent(mut transcript_evt) = event { + let Some(ref mut transcript) = transcript_evt.transcript else { + continue; + }; - if let Some(result) = transcript_evt - .transcript - .and_then(|transcript| transcript.results) - .and_then(|mut results| results.drain(..).next()) - { - gst::trace!(CAT, imp = self.imp, "Received: {result:?}"); + let t = TranscriptDef { + results: transcript.results.clone(), + }; - if let Some(alternative) = result - .alternatives - .and_then(|mut alternatives| alternatives.drain(..).next()) - { - ready_items = alternative.items.and_then(|items| { - self.get_ready_transcript_items(items, result.is_partial) - }); - } - } + let serialized = serde_json::to_string(&t).expect("serializable"); - if let Some(ready_items) = ready_items { - return Ok(ready_items.into()); - } + let Some(result) = transcript + .results + .as_mut() + .and_then(|results| results.drain(..).next()) + else { + continue; + }; + + let ready_items = result + .alternatives + .and_then(|mut alternatives| alternatives.drain(..).next()) + .and_then(|alternative| alternative.items) + .map(|items| self.get_ready_transcript_items(items, result.is_partial)) + .unwrap_or(vec![]); + + return 
Ok(TranscriptEvent::Transcript { + items: ready_items.into(), + serialized: Some(serialized), + }); } else { gst::warning!( CAT, @@ -236,7 +241,9 @@ impl TranscriberStream { &mut self, mut items: Vec, partial: bool, - ) -> Option> { + ) -> Vec { + let mut output = vec![]; + if items.len() < self.partial_index { gst::error!( CAT, @@ -250,11 +257,9 @@ impl TranscriberStream { self.partial_index = 0; } - return None; + return output; } - let mut output = vec![]; - for item in items.drain(self.partial_index..) { if !item.stable().unwrap_or(false) { break; @@ -281,10 +286,6 @@ impl TranscriberStream { self.partial_index = 0; } - if output.is_empty() { - return None; - } - - Some(output) + output } } diff --git a/net/aws/src/transcriber/translate.rs b/net/aws/src/transcriber/translate.rs index bcb2b08c..3ebda47d 100644 --- a/net/aws/src/transcriber/translate.rs +++ b/net/aws/src/transcriber/translate.rs @@ -41,6 +41,12 @@ impl From<&TranscriptItem> for TranslatedItem { } } +#[derive(Debug)] +pub struct Translation { + pub items: Vec, + pub translation: String, +} + pub struct TranslateLoop { pad: glib::subclass::ObjectImplRef, client: aws_translate::Client, @@ -48,7 +54,7 @@ pub struct TranslateLoop { output_lang: String, tokenization_method: TranslationTokenizationMethod, transcript_rx: mpsc::Receiver>>, - translate_tx: mpsc::Sender>, + translate_tx: mpsc::Sender, } impl TranslateLoop { @@ -59,7 +65,7 @@ impl TranslateLoop { output_lang: &str, tokenization_method: TranslationTokenizationMethod, transcript_rx: mpsc::Receiver>>, - translate_tx: mpsc::Sender>, + translate_tx: mpsc::Sender, ) -> Self { let aws_config = imp.aws_config.lock().unwrap(); let aws_config = aws_config @@ -171,7 +177,7 @@ impl TranslateLoop { vec![TranslatedItem { pts: first_pts, duration: last_pts.saturating_sub(first_pts) + last_duration, - content: translated_text, + content: translated_text.clone(), }] } Tokenization::SpanBased => span_tokenize_items(&translated_text, ts_duration_list), @@ 
-179,7 +185,15 @@ impl TranslateLoop { gst::trace!(CAT, imp = self.pad, "Sending {translated_items:?}"); - if self.translate_tx.send(translated_items).await.is_err() { + if self + .translate_tx + .send(Translation { + items: translated_items, + translation: translated_text, + }) + .await + .is_err() + { gst::info!( CAT, imp = self.pad,