diff --git a/docs/plugins/gst_plugins_cache.json b/docs/plugins/gst_plugins_cache.json index 2ac55be0..7f405d53 100644 --- a/docs/plugins/gst_plugins_cache.json +++ b/docs/plugins/gst_plugins_cache.json @@ -619,7 +619,7 @@ "rank": "none" }, "awstranscriber": { - "author": "Jordan Petridis , Mathieu Duponchelle ", + "author": "Jordan Petridis , Mathieu Duponchelle , François Laignel ", "description": "Speech to Text filter, using AWS transcribe", "hierarchy": [ "GstAwsTranscriber", diff --git a/net/aws/Cargo.toml b/net/aws/Cargo.toml index ef6370e9..c329009f 100644 --- a/net/aws/Cargo.toml +++ b/net/aws/Cargo.toml @@ -11,36 +11,30 @@ edition = "2021" rust-version = "1.66" [dependencies] -bytes = "1.0" -futures = "0.3" -gst = { package = "gstreamer", git = "https://gitlab.freedesktop.org/gstreamer/gstreamer-rs" } -gst-base = { package = "gstreamer-base", git = "https://gitlab.freedesktop.org/gstreamer/gstreamer-rs" } -gst-audio = { package = "gstreamer-audio", git = "https://gitlab.freedesktop.org/gstreamer/gstreamer-rs", features = ["v1_16"] } +async-stream = "0.3.4" +base32 = "0.4" aws-config = "0.54.0" aws-sdk-s3 = "0.24.0" -aws-sdk-transcribe = "0.24.0" +aws-sdk-transcribestreaming = "0.24.0" aws-types = "0.54.0" aws-credential-types = "0.54.0" aws-sig-auth = "0.54.0" aws-smithy-http = { version = "0.54.0", features = [ "rt-tokio" ] } aws-smithy-types = "0.54.0" +bytes = "1.0" +futures = "0.3" +gio = { git = "https://github.com/gtk-rs/gtk-rs-core.git", package = "gio" } +gst = { package = "gstreamer", git = "https://gitlab.freedesktop.org/gstreamer/gstreamer-rs" } +gst-base = { package = "gstreamer-base", git = "https://gitlab.freedesktop.org/gstreamer/gstreamer-rs" } +gst-audio = { package = "gstreamer-audio", git = "https://gitlab.freedesktop.org/gstreamer/gstreamer-rs", features = ["v1_16"] } http = "0.2.7" -chrono = "0.4" -url = "2" +once_cell = "1.0" percent-encoding = "2" tokio = { version = "1.0", features = [ "full" ] } -async-tungstenite = { version = "0.20", features = ["tokio", "tokio-runtime", "tokio-native-tls"] } -nom = "7" -crc = "3" -byteorder = "1.3.4" -once_cell = "1.0" serde = "1" serde_derive = "1" serde_json = "1" -atomic_refcell = "0.1" -base32 = "0.4" -backoff = { version = "0.4", features = [ "futures", "tokio" ] } -gio = { git = "https://github.com/gtk-rs/gtk-rs-core.git", package = "gio" } +url = "2" [dev-dependencies] chrono = { version = "0.4", features = [ "alloc" ] } diff --git a/net/aws/src/transcriber/imp.rs b/net/aws/src/transcriber/imp.rs index 3918c61c..18e45c08 100644 --- a/net/aws/src/transcriber/imp.rs +++ b/net/aws/src/transcriber/imp.rs @@ -1,4 +1,5 @@ // Copyright (C) 2020 Mathieu Duponchelle +// Copyright (C) 2023 François Laignel // // This Source Code Form is subject to the terms of the Mozilla Public License, v2.0. // If a copy of the MPL was not distributed with this file, You can obtain one at @@ -9,91 +10,22 @@ use gst::glib; use gst::prelude::*; use gst::subclass::prelude::*; -use gst::{element_imp_error, error_msg, loggable_error}; -use std::default::Default; +use aws_sdk_transcribestreaming as aws_transcribe; +use aws_sdk_transcribestreaming::model; -use aws_config::default_provider::credentials::DefaultCredentialsChain; -use aws_credential_types::{provider::ProvideCredentials, Credentials}; -use aws_sig_auth::signer::{self, HttpSignatureType, OperationSigningConfig, RequestConfig}; -use aws_smithy_http::body::SdkBody; -use aws_types::region::{Region, SigningRegion}; -use aws_types::SigningService; -use std::time::{Duration, SystemTime}; - -use chrono::prelude::*; -use http::Uri; - -use async_tungstenite::tungstenite::error::Error as WsError; -use async_tungstenite::{tokio::connect_async, tungstenite::Message}; use futures::channel::mpsc; -use futures::future::{abortable, AbortHandle}; use futures::prelude::*; -use tokio::runtime; +use tokio::{runtime, task}; use std::cmp::Ordering; use std::collections::VecDeque; -use std::pin::Pin; use std::sync::Mutex; -use atomic_refcell::AtomicRefCell; - -use super::packet::*; - -use serde_derive::{Deserialize, Serialize}; - use once_cell::sync::Lazy; use super::{AwsTranscriberResultStability, AwsTranscriberVocabularyFilterMethod}; -const DEFAULT_TRANSCRIBER_REGION: &str = "us-east-1"; - -#[derive(Deserialize, Serialize, Debug)] -#[serde(rename_all = "PascalCase")] -struct TranscriptItem { - content: String, - end_time: f32, - start_time: f32, - #[serde(rename = "Type")] - type_: String, - stable: bool, -} - -#[derive(Deserialize, Serialize, Debug)] -#[serde(rename_all = "PascalCase")] -struct TranscriptAlternative { - items: Vec, - transcript: String, -} - -#[derive(Deserialize, Serialize, Debug)] -#[serde(rename_all = "PascalCase")] -struct TranscriptResult { - alternatives: Vec, - end_time: f32, - start_time: f32, - is_partial: bool, - result_id: String, -} - -#[derive(Deserialize, Debug)] -#[serde(rename_all = "PascalCase")] -struct TranscriptTranscript { - results: Vec, -} - -#[derive(Deserialize, Debug)] -#[serde(rename_all = "PascalCase")] -struct Transcript { - transcript: TranscriptTranscript, -} - -#[derive(Deserialize, Debug)] -#[serde(rename_all = "PascalCase")] -struct ExceptionMessage { - message: String, -} - static CAT: Lazy = Lazy::new(|| { gst::DebugCategory::new( "awstranscribe", @@ -110,8 +42,10 @@ static RUNTIME: Lazy = Lazy::new(|| { .unwrap() }); +const DEFAULT_TRANSCRIBER_REGION: &str = "us-east-1"; const DEFAULT_LATENCY: gst::ClockTime = gst::ClockTime::from_seconds(8); const DEFAULT_LATENESS: gst::ClockTime = gst::ClockTime::ZERO; +const DEFAULT_LANGUAGE_CODE: &str = "en-US"; const DEFAULT_STABILITY: AwsTranscriberResultStability = AwsTranscriberResultStability::Low; const DEFAULT_VOCABULARY_FILTER_METHOD: AwsTranscriberVocabularyFilterMethod = AwsTranscriberVocabularyFilterMethod::Mask; @@ -121,7 +55,7 @@ const GRANULARITY: gst::ClockTime = gst::ClockTime::from_mseconds(100); struct Settings { latency: gst::ClockTime, lateness: gst::ClockTime, - language_code: Option, + language_code: String, vocabulary: Option, session_id: Option, results_stability: AwsTranscriberResultStability, @@ -132,12 +66,12 @@ struct Settings { vocabulary_filter_method: AwsTranscriberVocabularyFilterMethod, } -impl Default for Settings { +impl std::default::Default for Settings { fn default() -> Self { Self { latency: DEFAULT_LATENCY, lateness: DEFAULT_LATENESS, - language_code: Some("en-US".to_string()), + language_code: DEFAULT_LANGUAGE_CODE.to_string(), vocabulary: None, session_id: None, results_stability: DEFAULT_STABILITY, @@ -150,29 +84,55 @@ impl Default for Settings { } } +#[derive(Debug)] +struct TranscriptionSettings { + lang_code: model::LanguageCode, + sample_rate: i32, + vocabulary: Option, + vocabulary_filter: Option, + vocabulary_filter_method: model::VocabularyFilterMethod, + session_id: Option, + results_stability: model::PartialResultsStability, +} + +impl TranscriptionSettings { + fn from(settings: &Settings, sample_rate: i32) -> Self { + TranscriptionSettings { + lang_code: settings.language_code.as_str().into(), + sample_rate, + vocabulary: settings.vocabulary.clone(), + vocabulary_filter: settings.vocabulary_filter.clone(), + vocabulary_filter_method: settings.vocabulary_filter_method.into(), + session_id: settings.session_id.clone(), + results_stability: settings.results_stability.into(), + } + } +} + struct State { - connected: bool, - sender: Option>, - recv_abort_handle: Option, - send_abort_handle: Option, + client: Option, + buffer_tx: Option>, + transcript_tx: Option>, + ws_loop_handle: Option>>, in_segment: gst::FormattedSegment, out_segment: gst::FormattedSegment, seqnum: gst::Seqnum, buffers: VecDeque, send_eos: bool, + // FIXME never set to true discont: bool, partial_index: usize, send_events: bool, start_time: Option, } -impl Default for State { +impl std::default::Default for State { fn default() -> Self { Self { - connected: false, - sender: None, - recv_abort_handle: None, - send_abort_handle: None, + client: None, + buffer_tx: None, + transcript_tx: None, + ws_loop_handle: None, in_segment: gst::FormattedSegment::new(), out_segment: gst::FormattedSegment::new(), seqnum: gst::Seqnum::next(), @@ -186,36 +146,11 @@ impl Default for State { } } -type WsSink = Pin + Send + Sync>>; - pub struct Transcriber { srcpad: gst::Pad, sinkpad: gst::Pad, settings: Mutex, state: Mutex, - ws_sink: AtomicRefCell>, -} - -fn build_packet(payload: &[u8]) -> Vec { - let headers = [ - Header { - name: ":event-type".into(), - value: "AudioEvent".into(), - value_type: 7, - }, - Header { - name: ":content-type".into(), - value: "application/octet-stream".into(), - value_type: 7, - }, - Header { - name: ":message-type".into(), - value: "event".into(), - value_type: 7, - }, - ]; - - encode_packet(payload, &headers).expect("foobar") } impl Transcriber { @@ -223,12 +158,7 @@ impl Transcriber { /* First, check our pending buffers */ let mut items = vec![]; - let now = match self.obj().current_running_time() { - Some(now) => now, - None => { - return true; - } - }; + let Some(now) = self.obj().current_running_time() else { return true }; let latency = self.settings.lock().unwrap().latency; @@ -249,9 +179,7 @@ impl Transcriber { gst::trace!( CAT, imp: self, - "Checking now {} if item is ready for dequeuing, PTS {}, threshold {} vs {}", - now, - pts, + "Checking now {now} if item is ready for dequeuing, PTS {pts}, threshold {} vs {}", pts + latency.saturating_sub(3 * GRANULARITY), now - start_time ); @@ -295,7 +223,7 @@ impl Transcriber { .duration(pts - last_position) .seqnum(seqnum) .build(); - gst::log!(CAT, "Pushing gap: {} -> {}", last_position, pts); + gst::log!(CAT, "Pushing gap: {last_position} -> {pts}"); if !self.srcpad.push_event(gap_event) { return false; } @@ -306,9 +234,7 @@ impl Transcriber { gst::warning!( CAT, imp: self, - "Updating item PTS ({} < {}), consider increasing latency", - pts, - last_position + "Updating item PTS ({pts} < {last_position}), consider increasing latency", ); pts = last_position; @@ -326,7 +252,7 @@ impl Transcriber { last_position = pts + duration; - gst::debug!(CAT, "Pushing buffer: {} -> {}", pts, pts + duration); + gst::debug!(CAT, "Pushing buffer: {pts} -> {}", pts + duration); if self.srcpad.push(buf).is_err() { return false; @@ -337,9 +263,7 @@ impl Transcriber { gst::trace!( CAT, imp: self, - "Checking now: {} if we need to push a gap, last_position: {}, threshold: {}", - now, - last_position, + "Checking now: {now} if we need to push a gap, last_position: {last_position}, threshold: {}", last_position + latency.saturating_sub(GRANULARITY) ); @@ -353,8 +277,7 @@ impl Transcriber { gst::log!( CAT, - "Pushing gap: {} -> {}", - last_position, + "Pushing gap: {last_position} -> {}", last_position + duration ); @@ -374,15 +297,15 @@ impl Transcriber { true } - fn enqueue(&self, state: &mut State, alternative: &TranscriptAlternative, partial: bool) { + fn enqueue(&self, state: &mut State, items: &[model::Item], partial: bool) { let lateness = self.settings.lock().unwrap().lateness; - if alternative.items.len() <= state.partial_index { + if items.len() <= state.partial_index { gst::error!( CAT, imp: self, "sanity check failed, alternative length {} < partial_index {}", - alternative.items.len(), + items.len(), state.partial_index ); @@ -393,40 +316,42 @@ impl Transcriber { return; } - for item in &alternative.items[state.partial_index..] { - let start_time = - ((item.start_time as f64 * 1_000_000_000.0) as u64).nseconds() + lateness; - let end_time = ((item.end_time as f64 * 1_000_000_000.0) as u64).nseconds() + lateness; + for item in &items[state.partial_index..] { + let start_time = ((item.start_time * 1_000_000_000.0) as u64).nseconds() + lateness; + let end_time = ((item.end_time * 1_000_000_000.0) as u64).nseconds() + lateness; - if !item.stable { + if !item.stable().unwrap_or(false) { break; } - /* Should be sent now */ - gst::debug!( - CAT, - imp: self, - "Item is ready for queuing: {}, PTS {}", - item.content, - start_time - ); - let mut buf = gst::Buffer::from_mut_slice(item.content.clone().into_bytes()); + // FIXME could probably just unwrap + if let Some(content) = item.content() { + /* Should be sent now */ + gst::debug!( + CAT, + imp: self, + "Item is ready for queuing: {content}, PTS {start_time}", + ); - { - let buf = buf.get_mut().unwrap(); + let mut buf = gst::Buffer::from_mut_slice(content.to_string().into_bytes()); + { + let buf = buf.get_mut().unwrap(); - if state.discont { - buf.set_flags(gst::BufferFlags::DISCONT); - state.discont = false; + if state.discont { + buf.set_flags(gst::BufferFlags::DISCONT); + state.discont = false; + } + + buf.set_pts(start_time); + buf.set_duration(end_time - start_time); } - buf.set_pts(start_time); - buf.set_duration(end_time - start_time); + state.partial_index += 1; + + state.buffers.push_back(buf); + } else { + gst::debug!(CAT, imp: self, "None transcript item content"); } - - state.partial_index += 1; - - state.buffers.push_back(buf); } if !partial { @@ -434,12 +359,11 @@ impl Transcriber { } } - fn loop_fn(&self, receiver: &mut mpsc::Receiver) -> Result<(), gst::ErrorMessage> { + fn pad_loop_fn(&self, receiver: &mut mpsc::Receiver) -> Result<(), ()> { let mut events = { let mut events = vec![]; - let mut state = self.state.lock().unwrap(); - + let state = self.state.lock().unwrap(); if state.send_events { events.push( gst::event::StreamStart::builder("transcription") @@ -461,112 +385,71 @@ impl Transcriber { .seqnum(state.seqnum) .build(), ); - - state.send_events = false; } events }; - for event in events.drain(..) { - gst::info!(CAT, imp: self, "Sending {:?}", event); - self.srcpad.push_event(event); + if !events.is_empty() { + for event in events.drain(..) { + gst::info!(CAT, imp: self, "Sending {event:?}"); + self.srcpad.push_event(event); + } + + self.state.lock().unwrap().send_events = false; } let future = async move { - let msg = match receiver.next().await { - Some(msg) => msg, - /* Sender was closed */ - None => { - let _ = self.srcpad.pause_task(); - return Ok(()); - } + enum Winner { + TranscriptEvent(Option), + Timeout, + } + + let timer = tokio::time::sleep(GRANULARITY.into()).fuse(); + futures::pin_mut!(timer); + + let race_res = futures::select_biased! { + transcript_evt = receiver.next() => Winner::TranscriptEvent(transcript_evt), + _ = timer => Winner::Timeout, }; - match msg { - Message::Binary(buf) => { - let (_, pkt) = parse_packet(&buf).map_err(|err| { - gst::error!(CAT, imp: self, "Failed to parse packet: {}", err); - error_msg!( - gst::StreamError::Failed, - ["Failed to parse packet: {}", err] - ) - })?; + use Winner::*; + match race_res { + TranscriptEvent(Some(transcript_evt)) => { + if let Some(result) = transcript_evt + .transcript + .as_ref() + .and_then(|transcript| transcript.results()) + .and_then(|results| results.get(0)) + { + gst::trace!(CAT, imp: self, "Received: {result:?}"); - let payload = std::str::from_utf8(pkt.payload).unwrap(); - - if packet_is_exception(&pkt) { - let message: ExceptionMessage = - serde_json::from_str(payload).map_err(|err| { - gst::error!( - CAT, - imp: self, - "Unexpected exception message: {} ({})", - payload, - err - ); - error_msg!( - gst::StreamError::Failed, - ["Unexpected exception message: {} ({})", payload, err] - ) - })?; - gst::error!(CAT, imp: self, "AWS raised an error: {}", message.message); - - return Err(error_msg!( - gst::StreamError::Failed, - ["AWS raised an error: {}", message.message] - )); - } - - let transcript: Transcript = serde_json::from_str(payload).map_err(|err| { - error_msg!( - gst::StreamError::Failed, - ["Unexpected binary message: {} ({})", payload, err] - ) - })?; - - if let Some(result) = transcript.transcript.results.get(0) { - gst::trace!( - CAT, - imp: self, - "result: {}", - serde_json::to_string_pretty(&result).unwrap(), - ); - - if let Some(alternative) = result.alternatives.get(0) { - let mut state = self.state.lock().unwrap(); - - self.enqueue(&mut state, alternative, result.is_partial) + if let Some(alternative) = result + .alternatives + .as_ref() + .and_then(|alternatives| alternatives.get(0)) + { + if let Some(items) = alternative.items() { + let mut state = self.state.lock().unwrap(); + self.enqueue(&mut state, items, result.is_partial) + } } } - - Ok(()) } - - _ => Ok(()), + TranscriptEvent(None) => { + gst::info!(CAT, imp: self, "Transcript evt channel disconnected"); + // Something bad happened elsewhere, let the other side report. + return Err(()); + } + Timeout => (), } - }; - /* Wrap in a timeout so we can push gaps regularly */ - let future = async move { - match tokio::time::timeout(GRANULARITY.into(), future).await { - Err(_) => { - if !self.dequeue() { - gst::info!(CAT, imp: self, "Failed to push gap event, pausing"); - - let _ = self.srcpad.pause_task(); - } - Ok(()) - } - Ok(res) => { - if !self.dequeue() { - gst::info!(CAT, imp: self, "Failed to push gap event, pausing"); - - let _ = self.srcpad.pause_task(); - } - res - } + if !self.dequeue() { + gst::info!(CAT, imp: self, "Failed to dequeue buffer, pausing"); + let _ = self.srcpad.pause_task(); } + + Ok(()) }; let _enter = RUNTIME.enter(); @@ -574,26 +457,53 @@ impl Transcriber { } fn start_task(&self) -> Result<(), gst::LoggableError> { - let (sender, mut receiver) = mpsc::channel(1); + let mut state = self.state.lock().unwrap(); - { - let mut state = self.state.lock().unwrap(); - state.sender = Some(sender); - } + let (transcript_tx, mut transcript_rx) = mpsc::channel(1); let imp = self.ref_counted(); let res = self.srcpad.start_task(move || { - if let Err(err) = imp.loop_fn(&mut receiver) { - element_imp_error!(imp, gst::StreamError::Failed, ["Streaming failed: {}", err]); + if imp.pad_loop_fn(&mut transcript_rx).is_err() { + // Pad loop fn reported an unrecoverable error. + // FIXME we should probably stop the task as + // there's nothing we can do about it except restarting. let _ = imp.srcpad.pause_task(); } }); + if res.is_err() { - return Err(loggable_error!(CAT, "Failed to start pad task")); + state.transcript_tx = None; + return Err(gst::loggable_error!(CAT, "Failed to start pad task")); } + + state.transcript_tx = Some(transcript_tx); + Ok(()) } + fn stop_task(&self) { + let mut state = self.state.lock().unwrap(); + + let _ = self.srcpad.stop_task(); + + if let Some(ws_loop_handle) = state.ws_loop_handle.take() { + ws_loop_handle.abort(); + } + + state.transcript_tx = None; + state.buffer_tx = None; + } + + fn stop_ws_loop(&self) { + let mut state = self.state.lock().unwrap(); + + if let Some(ws_loop_handle) = state.ws_loop_handle.take() { + ws_loop_handle.abort(); + } + + state.buffer_tx = None; + } + fn src_activatemode( &self, _pad: &gst::Pad, @@ -603,24 +513,18 @@ impl Transcriber { if active { self.start_task()?; } else { - { - let mut state = self.state.lock().unwrap(); - state.sender = None; - } - - let _ = self.srcpad.stop_task(); + self.stop_task(); } Ok(()) } fn src_query(&self, pad: &gst::Pad, query: &mut gst::QueryRef) -> bool { - use gst::QueryViewMut; - - gst::log!(CAT, obj: pad, "Handling query {:?}", query); + gst::log!(CAT, obj: pad, "Handling query {query:?}"); + use gst::QueryViewMut::*; match query.view_mut() { - QueryViewMut::Latency(q) => { + Latency(q) => { let mut peer_query = gst::query::Latency::new(); let ret = self.sinkpad.peer_query(&mut peer_query); @@ -632,7 +536,7 @@ impl Transcriber { } ret } - QueryViewMut::Position(q) => { + Position(q) => { if q.format() == gst::Format::Time { let state = self.state.lock().unwrap(); q.set( @@ -650,44 +554,29 @@ impl Transcriber { } fn sink_event(&self, pad: &gst::Pad, event: gst::Event) -> bool { - use gst::EventView; - - gst::log!(CAT, obj: pad, "Handling event {:?}", event); + gst::log!(CAT, obj: pad, "Handling event {event:?}"); + use gst::EventView::*; match event.view() { - EventView::Eos(_) => match self.handle_buffer(pad, None) { - Err(err) => { - gst::error!(CAT, "Failed to send EOS to AWS: {}", err); - false - } - Ok(_) => true, - }, - EventView::FlushStart(_) => { + Eos(_) => { + self.stop_ws_loop(); + + true + } + FlushStart(_) => { gst::info!(CAT, imp: self, "Received flush start, disconnecting"); - let mut ret = gst::Pad::event_default(pad, Some(&*self.obj()), event); - - match self.srcpad.stop_task() { - Err(err) => { - gst::error!(CAT, imp: self, "Failed to stop srcpad task: {}", err); - - self.disconnect(); - - ret = false; - } - Ok(_) => { - self.disconnect(); - } - }; + let ret = gst::Pad::event_default(pad, Some(&*self.obj()), event); + self.stop_task(); ret } - EventView::FlushStop(_) => { + FlushStop(_) => { gst::info!(CAT, imp: self, "Received flush stop, restarting task"); if gst::Pad::event_default(pad, Some(&*self.obj()), event) { match self.start_task() { Err(err) => { - gst::error!(CAT, imp: self, "Failed to start srcpad task: {}", err); + gst::error!(CAT, imp: self, "Failed to start srcpad task: {err}"); false } Ok(_) => true, @@ -696,13 +585,13 @@ impl Transcriber { false } } - EventView::Segment(e) => { + Segment(e) => { let segment = match e.segment().clone().downcast::() { Err(segment) => { - element_imp_error!( + gst::element_imp_error!( self, gst::StreamError::Format, - ["Only Time segments supported, got {:?}", segment.format(),] + ["Only Time segments supported, got {:?}", segment.format()] ); return false; } @@ -716,355 +605,285 @@ impl Transcriber { true } - EventView::Tag(_) => true, - EventView::Caps(e) => { - gst::info!(CAT, "Received caps {:?}", e); + Tag(_) => true, + Caps(c) => { + gst::info!(CAT, "Received caps {c:?}"); true } - EventView::StreamStart(_) => true, + StreamStart(_) => true, _ => gst::Pad::event_default(pad, Some(&*self.obj()), event), } } - async fn sync_and_send( - &self, - buffer: Option, - ) -> Result { - let mut delay = None; - - { - let state = self.state.lock().unwrap(); - - if let Some(buffer) = &buffer { - let running_time = state.in_segment.to_running_time(buffer.pts()); - let now = self.obj().current_running_time(); - - delay = running_time.opt_checked_sub(now).ok().flatten(); - } - } - - if let Some(delay) = delay { - tokio::time::sleep(delay.into()).await; - } - - if let Some(ws_sink) = self.ws_sink.borrow_mut().as_mut() { - if let Some(buffer) = buffer { - let data = buffer.map_readable().unwrap(); - for chunk in data.chunks(8192) { - let packet = build_packet(chunk); - ws_sink.send(Message::Binary(packet)).await.map_err(|err| { - gst::error!(CAT, imp: self, "Failed sending packet: {}", err); - gst::FlowError::Error - })?; - } - } else { - // EOS - let packet = build_packet(&[]); - ws_sink.send(Message::Binary(packet)).await.map_err(|err| { - gst::error!(CAT, imp: self, "Failed sending packet: {}", err); - gst::FlowError::Error - })?; - } - } - - Ok(gst::FlowSuccess::Ok) - } - - fn handle_buffer( - &self, - _pad: &gst::Pad, - buffer: Option, - ) -> Result { - gst::log!(CAT, imp: self, "Handling {:?}", buffer); - - self.ensure_connection().map_err(|err| { - element_imp_error!( - self, - gst::StreamError::Failed, - ["Streaming failed: {}", err] - ); - gst::FlowError::Error - })?; - - let (future, abort_handle) = abortable(self.sync_and_send(buffer)); - - self.state.lock().unwrap().send_abort_handle = Some(abort_handle); - - let res = { - let _enter = RUNTIME.enter(); - futures::executor::block_on(future) - }; - - match res { - Err(_) => Err(gst::FlowError::Flushing), - Ok(res) => res, - } - } - fn sink_chain( &self, pad: &gst::Pad, buffer: gst::Buffer, ) -> Result { - self.handle_buffer(pad, Some(buffer)) + gst::log!(CAT, obj: pad, "Handling {buffer:?}"); + + self.ensure_connection().map_err(|err| { + gst::element_imp_error!(self, gst::StreamError::Failed, ["Streaming failed: {err}"]); + gst::FlowError::Error + })?; + + let Some(mut buffer_tx) = self.state.lock().unwrap().buffer_tx.take() else { + gst::log!(CAT, obj: pad, "Flushing"); + return Err(gst::FlowError::Flushing); + }; + + futures::executor::block_on(buffer_tx.send(buffer)).map_err(|err| { + gst::element_imp_error!(self, gst::StreamError::Failed, ["Streaming failed: {err}"]); + gst::FlowError::Error + })?; + + self.state.lock().unwrap().buffer_tx = Some(buffer_tx); + + Ok(gst::FlowSuccess::Ok) } fn ensure_connection(&self) -> Result<(), gst::ErrorMessage> { - let state = self.state.lock().unwrap(); - - if state.connected { - return Ok(()); + enum ClientStage { + Ready(aws_transcribe::Client), + NotReady { + access_key: Option, + secret_access_key: Option, + session_token: Option, + }, } - let in_caps = self.sinkpad.current_caps().unwrap(); - let s = in_caps.structure(0).unwrap(); - let sample_rate = s.get::("rate").unwrap(); + let (client_stage, transcription_settings, transcript_tx) = { + let mut state = self.state.lock().unwrap(); - let settings = self.settings.lock().unwrap(); + if let Some(ref ws_loop_handle) = state.ws_loop_handle { + if ws_loop_handle.is_finished() { + state.ws_loop_handle = None; - if settings.latency + settings.lateness <= 2 * GRANULARITY { - gst::error!( - CAT, - imp: self, - "latency + lateness must be greater than 200 milliseconds" - ); - return Err(error_msg!( - gst::LibraryError::Settings, - ["latency + lateness must be greater than 200 milliseconds"] - )); - } + const ERR: &str = "ws loop terminated unexpectedly"; + gst::error!(CAT, imp: self, "{ERR}"); + return Err(gst::error_msg!(gst::LibraryError::Failed, ["{ERR}"])); + } - gst::info!(CAT, imp: self, "Connecting .."); + return Ok(()); + } - let region = Region::new(DEFAULT_TRANSCRIBER_REGION); - let access_key = settings.access_key.as_ref(); - let secret_access_key = settings.secret_access_key.as_ref(); - let session_token = settings.session_token.clone(); + let transcript_tx = state + .transcript_tx + .take() + .expect("attempting to spawn the ws loop, but the srcpad task hasn't been started"); - let credentials = match (access_key, secret_access_key) { - (Some(key), Some(secret_key)) => { - gst::debug!( - CAT, - imp: self, - "Using provided access and secret access key" + let settings = self.settings.lock().unwrap(); + + if settings.latency + settings.lateness <= 2 * GRANULARITY { + const ERR: &str = "latency + lateness must be greater than 200 milliseconds"; + gst::error!(CAT, imp: self, "{ERR}"); + return Err(gst::error_msg!(gst::LibraryError::Settings, ["{ERR}"])); + } + + let in_caps = self.sinkpad.current_caps().unwrap(); + let s = in_caps.structure(0).unwrap(); + let sample_rate = s.get::("rate").unwrap(); + + let transcription_settings = TranscriptionSettings::from(&settings, sample_rate); + + let client_stage = if let Some(client) = state.client.take() { + ClientStage::Ready(client) + } else { + ClientStage::NotReady { + access_key: settings.access_key.to_owned(), + secret_access_key: settings.secret_access_key.to_owned(), + session_token: settings.session_token.to_owned(), + } + }; + + (client_stage, transcription_settings, transcript_tx) + }; + + let client = match client_stage { + ClientStage::Ready(client) => client, + ClientStage::NotReady { + access_key, + secret_access_key, + session_token, + } => { + gst::info!(CAT, imp: self, "Connecting..."); + let _enter_guard = RUNTIME.enter(); + + let config_loader = match (access_key, secret_access_key) { + (Some(key), Some(secret_key)) => { + gst::debug!(CAT, imp: self, "Using settings credentials"); + aws_config::ConfigLoader::default().credentials_provider( + aws_transcribe::Credentials::new( + key, + secret_key, + session_token, + None, + "translate", + ), + ) + } + _ => { + gst::debug!(CAT, imp: self, "Attempting to get credentials from env..."); + aws_config::from_env() + } + }; + + let config_loader = config_loader.region( + aws_config::meta::region::RegionProviderChain::default_provider() + .or_else(DEFAULT_TRANSCRIBER_REGION), ); - Ok(Credentials::new( - key.clone(), - secret_key.clone(), - session_token, - None, - "transcribe", - )) - } - _ => { - gst::debug!(CAT, imp: self, "Using default AWS credentials"); - let cred_future = async { - let cred = DefaultCredentialsChain::builder() - .region(region.clone()) - .build() - .await; - cred.provide_credentials().await - }; + let config = futures::executor::block_on(config_loader.load()); + gst::debug!(CAT, imp: self, "Using region {}", config.region().unwrap()); - RUNTIME.block_on(cred_future) - } - }; - - if let Err(e) = credentials { - return Err(error_msg!( - gst::LibraryError::Settings, - ["Failed to retrieve credentials with error {}", e] - )); - } - - let current_time = Utc::now(); - - let mut query_params = String::from("/stream-transcription-websocket?"); - - let language_code = settings - .language_code - .as_ref() - .expect("Language code is required"); - - query_params.push_str( - format!( - "language-code={}&media-encoding=pcm&sample-rate={}", - language_code, - &sample_rate.to_string(), - ) - .as_str(), - ); - - if let Some(ref vocabulary) = settings.vocabulary { - query_params.push_str(format!("&vocabulary-name={vocabulary}").as_str()); - } - - if let Some(ref vocabulary_filter) = settings.vocabulary_filter { - query_params.push_str(format!("&vocabulary-filter-name={vocabulary_filter}").as_str()); - - query_params.push_str( - format!( - "&vocabulary-filter-method={}", - match settings.vocabulary_filter_method { - AwsTranscriberVocabularyFilterMethod::Mask => "mask", - AwsTranscriberVocabularyFilterMethod::Remove => "remove", - AwsTranscriberVocabularyFilterMethod::Tag => "tag", - } - ) - .as_str(), - ); - } - - if let Some(ref session_id) = settings.session_id { - gst::debug!(CAT, imp: self, "Using session ID: {}", session_id); - query_params.push_str(format!("&session-id={session_id}").as_str()); - } - - query_params.push_str("&enable-partial-results-stabilization=true"); - - query_params.push_str( - format!( - "&partial-results-stability={}", - match settings.results_stability { - AwsTranscriberResultStability::High => "high", - AwsTranscriberResultStability::Medium => "medium", - AwsTranscriberResultStability::Low => "low", - } - ) - .as_str(), - ); - - drop(settings); - drop(state); - - let signer = signer::SigV4Signer::new(); - let mut operation_config = OperationSigningConfig::default_config(); - operation_config.signature_type = HttpSignatureType::HttpRequestQueryParams; - operation_config.expires_in = Some(Duration::from_secs(5 * 60)); // See commit a3db85d. - - let request_config = RequestConfig { - request_ts: SystemTime::from(current_time), - region: &SigningRegion::from(region.clone()), - service: &SigningService::from_static("transcribe"), - payload_override: None, - }; - let transcribe_uri = Uri::builder() - .scheme("https") - .authority(format!("transcribestreaming.{region}.amazonaws.com:8443").as_str()) - .path_and_query(query_params.clone()) - .build() - .map_err(|err| { - gst::error!(CAT, imp: self, "Failed to build HTTP request URI: {}", err); - error_msg!( - gst::CoreError::Failed, - ["Failed to build HTTP request URI: {}", err] - ) - })?; - let mut request = http::Request::builder() - .uri(transcribe_uri) - .body(SdkBody::empty()) - .expect("Failed to build valid request"); - let _signature = signer - .sign( - &operation_config, - &request_config, - &credentials.unwrap(), - &mut request, - ) - .map_err(|err| { - gst::error!(CAT, imp: self, "Failed to sign HTTP request: {}", err); - error_msg!( - gst::CoreError::Failed, - ["Failed to sign HTTP request: {}", err] - ) - })?; - let url = request.uri().to_string(); - - let (ws, _) = { - let _enter = RUNTIME.enter(); - futures::executor::block_on(connect_async(format!("wss{}", &url[5..]))).map_err( - |err| { - gst::error!(CAT, imp: self, "Failed to connect: {}", err); - error_msg!(gst::CoreError::Failed, ["Failed to connect: {}", err]) - }, - )? - }; - - let (ws_sink, mut ws_stream) = ws.split(); - - *self.ws_sink.borrow_mut() = Some(Box::pin(ws_sink)); - - let imp_weak = self.downgrade(); - let future = async move { - while let Some(transcribe) = imp_weak.upgrade() { - let msg = match ws_stream.next().await { - Some(msg) => msg, - None => { - let mut state = transcribe.state.lock().unwrap(); - state.send_eos = true; - break; - } - }; - - let msg = match msg { - Ok(msg) => msg, - Err(err) => { - gst::error!(CAT, imp: transcribe, "Failed to receive data: {}", err); - element_imp_error!( - transcribe, - gst::StreamError::Failed, - ["Streaming failed: {}", err] - ); - break; - } - }; - - let mut sender = transcribe.state.lock().unwrap().sender.clone(); - - if let Some(sender) = sender.as_mut() { - if sender.send(msg).await.is_err() { - break; - } - } + aws_transcribe::Client::new(&config) } }; let mut state = self.state.lock().unwrap(); - let (future, abort_handle) = abortable(future); + let (buffer_tx, buffer_rx) = mpsc::channel(1); + let ws_loop_handle = RUNTIME.spawn(self.build_ws_loop_fut( + client, + transcription_settings, + buffer_rx, + transcript_tx, + )); - state.recv_abort_handle = Some(abort_handle); - - RUNTIME.spawn(future); - - state.connected = true; - - gst::info!(CAT, imp: self, "Connected"); + state.ws_loop_handle = Some(ws_loop_handle); + state.buffer_tx = Some(buffer_tx); Ok(()) } + fn build_ws_loop_fut( + &self, + client: aws_transcribe::Client, + settings: TranscriptionSettings, + buffer_rx: mpsc::Receiver, + transcript_tx: mpsc::Sender, + ) -> impl Future> { + let imp_weak = self.downgrade(); + async move { + use gst::glib::subclass::ObjectImplWeakRef; + + // Guard that restores client & transcript_tx when the ws loop is done + struct Guard { + imp_weak: ObjectImplWeakRef, + client: Option, + transcript_tx: Option>, + } + + impl Guard { + fn client(&self) -> &aws_transcribe::Client { + self.client.as_ref().unwrap() + } + + fn transcript_tx(&mut self) -> &mut mpsc::Sender { + self.transcript_tx.as_mut().unwrap() + } + } + + impl Drop for Guard { + fn drop(&mut self) { + if let Some(imp) = self.imp_weak.upgrade() { + let mut state = imp.state.lock().unwrap(); + state.client = self.client.take(); + state.transcript_tx = self.transcript_tx.take(); + } + } + } + + let mut guard = Guard { + imp_weak: imp_weak.clone(), + client: Some(client), + transcript_tx: Some(transcript_tx), + }; + + // Stream the incoming buffers chunked + let chunk_stream = buffer_rx.flat_map(move |buffer: gst::Buffer| { + async_stream::stream! { + let data = buffer.map_readable().unwrap(); + use aws_transcribe::{model::{AudioEvent, AudioStream}, types::Blob}; + for chunk in data.chunks(8192) { + yield Ok(AudioStream::AudioEvent(AudioEvent::builder().audio_chunk(Blob::new(chunk)).build())); + } + } + }); + + let mut transcribe_builder = guard + .client() + .start_stream_transcription() + .language_code(settings.lang_code) + .media_sample_rate_hertz(settings.sample_rate) + .media_encoding(model::MediaEncoding::Pcm) + .enable_partial_results_stabilization(true) + .partial_results_stability(settings.results_stability) + .set_vocabulary_name(settings.vocabulary) + .set_session_id(settings.session_id); + + if let Some(vocabulary_filter) = settings.vocabulary_filter { + transcribe_builder = transcribe_builder + .vocabulary_filter_name(vocabulary_filter) + .vocabulary_filter_method(settings.vocabulary_filter_method); + } + + let mut output = transcribe_builder + .audio_stream(chunk_stream.into()) + .send() + .await + .map_err(|err| { + let err = format!("Transcribe ws init error: {err}"); + if let Some(imp) = imp_weak.upgrade() { + gst::error!(CAT, imp: imp, "{err}"); + } + gst::error_msg!(gst::LibraryError::Init, ["{err}"]) + })?; + + while let Some(event) = output + .transcript_result_stream + .recv() + .await + .map_err(|err| { + let err = format!("Transcribe ws stream error: {err}"); + if let Some(imp) = imp_weak.upgrade() { + gst::error!(CAT, imp: imp, "{err}"); + } + gst::error_msg!(gst::LibraryError::Failed, ["{err}"]) + })? + { + if let model::TranscriptResultStream::TranscriptEvent(transcript_evt) = event { + if guard.transcript_tx().send(transcript_evt).await.is_err() { + if let Some(imp) = imp_weak.upgrade() { + gst::debug!(CAT, imp: imp, "Terminated transcript_evt channel"); + } + break; + } + } else if let Some(imp) = imp_weak.upgrade() { + gst::warning!( + CAT, + imp: imp, + "Transcribe ws returned unknown event: consider upgrading the SDK" + ) + } else { + // imp has left the building + break; + } + } + + if let Some(imp) = imp_weak.upgrade() { + gst::debug!(CAT, imp: imp, "Exiting ws loop"); + } + + Ok(()) + } + } + fn disconnect(&self) { let mut state = self.state.lock().unwrap(); - gst::info!(CAT, imp: self, "Unpreparing"); - - if let Some(abort_handle) = state.recv_abort_handle.take() { - abort_handle.abort(); - } - - if let Some(abort_handle) = state.send_abort_handle.take() { - abort_handle.abort(); - } - + self.stop_task(); *state = State::default(); - - gst::info!( - CAT, - imp: self, - "Unprepared, connected: {}!", - state.connected - ); + gst::info!(CAT, imp: self, "Unprepared"); } } @@ -1098,7 +917,12 @@ impl ObjectSubclass for Transcriber { .activatemode_function(|pad, parent, mode, active| { Transcriber::catch_panic_pad_function( parent, - || Err(loggable_error!(CAT, "Panic activating src pad with mode")), + || { + Err(gst::loggable_error!( + CAT, + "Panic activating src pad with mode" + )) + }, |transcriber| transcriber.src_activatemode(pad, mode, active), ) }) @@ -1119,7 +943,6 @@ impl ObjectSubclass for Transcriber { sinkpad, settings, state: Mutex::new(State::default()), - ws_sink: AtomicRefCell::new(None), } } } @@ -1133,7 +956,7 @@ impl ObjectImpl for Transcriber { .blurb("The Language of the Stream, see \ \ for an up to date list of allowed languages") - .default_value(Some("en-US")) + .default_value(Some(DEFAULT_LANGUAGE_CODE)) .mutable_ready() .build(), glib::ParamSpecUInt::builder("latency") @@ -1325,7 +1148,7 @@ impl ElementImpl for Transcriber { "Transcriber", "Audio/Text/Filter", "Speech to Text filter, using AWS transcribe", - "Jordan Petridis , Mathieu Duponchelle ", + "Jordan Petridis , Mathieu Duponchelle , François Laignel ", ) }); @@ -1368,7 +1191,7 @@ impl ElementImpl for Transcriber { &self, transition: gst::StateChange, ) -> Result { - gst::info!(CAT, imp: self, "Changing state {:?}", transition); + gst::info!(CAT, imp: self, "Changing state {transition:?}"); let mut success = self.parent_change_state(transition)?; diff --git a/net/aws/src/transcriber/mod.rs b/net/aws/src/transcriber/mod.rs index db04c9ac..69ac6059 100644 --- a/net/aws/src/transcriber/mod.rs +++ b/net/aws/src/transcriber/mod.rs @@ -10,7 +10,8 @@ use gst::glib; use gst::prelude::*; mod imp; -mod packet; + +use aws_sdk_transcribestreaming::model::{PartialResultsStability, VocabularyFilterMethod}; #[derive(Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Clone, Copy, glib::Enum)] #[repr(u32)] @@ -31,6 +32,17 @@ pub enum AwsTranscriberResultStability { Low = 2, } +impl From for PartialResultsStability { + fn from(val: AwsTranscriberResultStability) -> Self { + use AwsTranscriberResultStability::*; + match val { + High => PartialResultsStability::High, + Medium => PartialResultsStability::Medium, + Low => PartialResultsStability::Low, + } + } +} + #[derive(Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Clone, Copy, glib::Enum)] #[repr(u32)] #[enum_type(name = "GstAwsTranscriberVocabularyFilterMethod")] @@ -44,6 +56,17 @@ pub enum AwsTranscriberVocabularyFilterMethod { Tag = 2, } +impl From for VocabularyFilterMethod { + fn from(val: AwsTranscriberVocabularyFilterMethod) -> Self { + use AwsTranscriberVocabularyFilterMethod::*; + match val { + Mask => VocabularyFilterMethod::Mask, + Remove => VocabularyFilterMethod::Remove, + Tag => VocabularyFilterMethod::Tag, + } + } +} + glib::wrapper! { pub struct Transcriber(ObjectSubclass) @extends gst::Element, gst::Object; } diff --git a/net/aws/src/transcriber/packet/mod.rs b/net/aws/src/transcriber/packet/mod.rs deleted file mode 100644 index d11a054c..00000000 --- a/net/aws/src/transcriber/packet/mod.rs +++ /dev/null @@ -1,174 +0,0 @@ -// Copyright (C) 2020 Jordan Petridis -// -// This Source Code Form is subject to the terms of the Mozilla Public License, v2.0. -// If a copy of the MPL was not distributed with this file, You can obtain one at -// . -// -// SPDX-License-Identifier: MPL-2.0 - -use byteorder::{BigEndian, WriteBytesExt}; -use nom::{ - self, bytes::complete::take, combinator::map_res, multi::many0, number::complete::be_u16, - number::complete::be_u32, number::complete::be_u8, sequence::tuple, IResult, -}; -use std::borrow::Cow; -use std::io::{self, Write}; - -const CRC: crc::Crc = crc::Crc::::new(&crc::CRC_32_ISO_HDLC); - -#[derive(Debug)] -struct Prelude { - total_bytes: u32, - header_bytes: u32, - #[allow(dead_code)] - prelude_crc: u32, -} - -#[derive(Debug)] -pub struct Header { - pub name: Cow<'static, str>, - pub value_type: u8, - pub value: Cow<'static, str>, -} - -#[derive(Debug)] -pub struct Packet<'a> { - #[allow(dead_code)] - prelude: Prelude, - headers: Vec
, - pub payload: &'a [u8], - #[allow(dead_code)] - msg_crc: u32, -} - -fn write_header(w: &mut W, header: &Header) -> Result<(), io::Error> { - w.write_u8(header.name.len() as u8)?; - w.write_all(header.name.as_bytes())?; - w.write_u8(header.value_type)?; - w.write_u16::(header.value.len() as u16)?; - w.write_all(header.value.as_bytes())?; - Ok(()) -} - -fn write_headers(w: &mut W, headers: &[Header]) -> Result<(), io::Error> { - for header in headers { - write_header(w, header)?; - } - Ok(()) -} - -pub fn encode_packet(payload: &[u8], headers: &[Header]) -> Result, io::Error> { - let mut res = Vec::with_capacity(1024); - - // Total length - res.write_u32::(0)?; - // Header length - res.write_u32::(0)?; - // Prelude CRC32 placeholder - res.write_u32::(0)?; - - // Write all headers - write_headers(&mut res, headers)?; - - // Rewrite header length - let header_length = res.len() - 12; - (&mut res[4..8]).write_u32::(header_length as u32)?; - - // Write payload - res.write_all(payload)?; - - // Rewrite total length - let total_length = res.len() + 4; - (&mut res[0..4]).write_u32::(total_length as u32)?; - - // Rewrite the prelude crc since we replaced the lengths - let prelude_crc = CRC.checksum(&res[0..8]); - (&mut res[8..12]).write_u32::(prelude_crc)?; - - // Message CRC - let message_crc = CRC.checksum(&res); - res.write_u32::(message_crc)?; - - Ok(res) -} - -fn parse_prelude(input: &[u8]) -> IResult<&[u8], Prelude> { - map_res( - tuple((be_u32, be_u32, be_u32)), - |(total_bytes, header_bytes, prelude_crc)| { - let sum = CRC.checksum(&input[0..8]); - if prelude_crc != sum { - return Err(nom::Err::Error(( - "Prelude CRC doesn't match", - nom::error::ErrorKind::MapRes, - ))); - } - - Ok(Prelude { - total_bytes, - header_bytes, - prelude_crc, - }) - }, - )(input) -} - -fn parse_header(input: &[u8]) -> IResult<&[u8], Header> { - let (input, header_length) = be_u8(input)?; - let (input, name) = map_res(take(header_length), std::str::from_utf8)(input)?; - let (input, value_type) = be_u8(input)?; - let (input, value_length) = be_u16(input)?; - let (input, value) = map_res(take(value_length), std::str::from_utf8)(input)?; - - let header = Header { - name: name.to_string().into(), - value_type, - value: value.to_string().into(), - }; - - Ok((input, header)) -} - -pub fn packet_is_exception(packet: &Packet) -> bool { - for header in &packet.headers { - if header.name == ":message-type" && header.value == "exception" { - return true; - } - } - - false -} - -pub fn parse_packet(input: &[u8]) -> IResult<&[u8], Packet> { - let (remainder, prelude) = parse_prelude(input)?; - - // Check the crc of the whole input - let sum = CRC.checksum(&input[..input.len() - 4]); - let (_, msg_crc) = be_u32(&input[input.len() - 4..])?; - - if msg_crc != sum { - return Err(nom::Err::Error(nom::error::Error::new( - b"Prelude CRC doesn't match", - nom::error::ErrorKind::MapRes, - ))); - } - - let (remainder, header_input) = take(prelude.header_bytes)(remainder)?; - let (_, headers) = many0(parse_header)(header_input)?; - - let payload_length = prelude.total_bytes - prelude.header_bytes - 4 - 12; - let (remainder, payload) = take(payload_length)(remainder)?; - - // only the message_crc we check before should be remaining now - assert_eq!(remainder.len(), 4); - - Ok(( - input, - Packet { - prelude, - headers, - payload, - msg_crc, - }, - )) -}