diff --git a/audio/transcribe/Cargo.toml b/audio/transcribe/Cargo.toml index ef14a588..9e5667c5 100644 --- a/audio/transcribe/Cargo.toml +++ b/audio/transcribe/Cargo.toml @@ -3,11 +3,15 @@ name = "gst-plugin-transcribe" version = "0.1.0" authors = ["Jordan Petridis "] edition = "2018" +description = "AWS Transcribe plugin" +repository = "https://gitlab.freedesktop.org/gstreamer/gst-plugins-rs" # FIXME: licence [dependencies] glib = { git = "https://github.com/gtk-rs/glib" } -gst = { git = "https://gitlab.freedesktop.org/gstreamer/gstreamer-rs", features = ["v1_16"], package = "gstreamer" } +# FIXME: +# gst = { git = "https://gitlab.freedesktop.org/gstreamer/gstreamer-rs", features = ["v1_16"], package = "gstreamer" } +gstreamer = { git = "https://gitlab.freedesktop.org/gstreamer/gstreamer-rs", features = ["v1_16"] } gst_base = { git = "https://gitlab.freedesktop.org/gstreamer/gstreamer-rs", features = ["v1_16"], package ="gstreamer-base" } rusoto_core = "0.43.0-beta.1" rusoto_credential = "0.43.0-beta.1" @@ -20,6 +24,10 @@ async-tungstenite = { version = "0.4", features = ["tokio", "tokio-runtime", "to nom = "5.1.1" crc = "1.8.1" byteorder = "1.3.4" +lazy_static = "1.4.0" +serde = "1" +serde_derive = "1" +serde_json = "1" [lib] name = "gsttranscribe" diff --git a/audio/transcribe/src/aws_transcribe_parse.rs b/audio/transcribe/src/aws_transcribe_parse.rs new file mode 100644 index 00000000..7a7f5815 --- /dev/null +++ b/audio/transcribe/src/aws_transcribe_parse.rs @@ -0,0 +1,1031 @@ +#![allow(unused)] + +use glib; +use glib::prelude::*; +use glib::subclass; +use glib::subclass::prelude::*; +use gst; +use gst::prelude::*; +use gst::subclass::prelude::*; + +use std::convert::TryInto; +use std::default::Default; +use std::env; + +use rusoto_core::Region; +use rusoto_credential; +use rusoto_credential::{EnvironmentProvider, ProvideAwsCredentials}; +use rusoto_transcribe; + +use rusoto_signature::signature::SignedRequest; +use rusoto_signature::signature::SignedRequestPayload; +use rusoto_signature::stream::ByteStream; +use rusoto_transcribe::Media; +use rusoto_transcribe::Settings as TranscriptionSettings; +use rusoto_transcribe::StartTranscriptionJobRequest; +use rusoto_transcribe::Transcribe; + +use async_tungstenite::tungstenite::error::Error as WsError; +use async_tungstenite::{tokio::connect_async, tungstenite::Message}; +use futures::channel::mpsc; +use futures::future::{abortable, AbortHandle}; +use futures::io::{AsyncReadExt, Cursor}; +use futures::prelude::*; +use futures::stream::SplitSink; +use tokio::runtime; + +use crc::crc32; + +use std::boxed::Box; +use std::collections::VecDeque; +use std::pin::Pin; +use std::sync::{Mutex, MutexGuard}; +use std::time::Duration; + +use lazy_static; + +use crate::packet::*; + +use serde_derive::{Deserialize, Serialize}; + +#[derive(Deserialize, Debug)] +#[serde(rename_all = "PascalCase")] +struct TranscriptItem { + content: String, + end_time: f32, + start_time: f32, + #[serde(rename = "Type")] + type_: String, +} + +#[derive(Deserialize, Debug)] +#[serde(rename_all = "PascalCase")] +struct TranscriptAlternative { + items: Vec, + transcript: String, +} + +#[derive(Deserialize, Debug)] +#[serde(rename_all = "PascalCase")] +struct TranscriptResult { + alternatives: Vec, + end_time: f32, + start_time: f32, + is_partial: bool, + result_id: String, +} + +#[derive(Deserialize, Debug)] +#[serde(rename_all = "PascalCase")] +struct TranscriptTranscript { + results: Vec, +} + +#[derive(Deserialize, Debug)] +#[serde(rename_all = "PascalCase")] +struct Transcript { + transcript: TranscriptTranscript, +} + +lazy_static! { + static ref CAT: gst::DebugCategory = { + gst::DebugCategory::new( + "awstranscribe", + gst::DebugColorFlags::empty(), + Some("AWS Transcribe element"), + ) + }; + static ref RUNTIME: runtime::Runtime = runtime::Builder::new() + .threaded_scheduler() + .enable_all() + .core_threads(1) + .build() + .unwrap(); +} + +const DEFAULT_LATENCY_MS: u32 = 30000; +const GRANULARITY_MS: u32 = 100; + +static PROPERTIES: [subclass::Property; 2] = [ + subclass::Property("language-code", |name| { + glib::ParamSpec::string( + name, + "Language-Code", + "The Language of the Stream, see \ + \ + for an up to date list of allowed languages", + Some("en-US"), + glib::ParamFlags::READWRITE, + ) + }), + subclass::Property("latency", |name| { + glib::ParamSpec::uint( + name, + "latency in ms", + "Amount of milliseconds to allow AWS transcribe", + GRANULARITY_MS, + std::u32::MAX, + DEFAULT_LATENCY_MS, + glib::ParamFlags::READWRITE, + ) + }), +]; + +#[derive(Debug, Clone)] +struct Settings { + latency_ms: u32, + language_code: Option, +} + +impl Default for Settings { + fn default() -> Self { + Self { + latency_ms: DEFAULT_LATENCY_MS, + language_code: Some("en-US".to_string()), + } + } +} + +struct State { + connected: bool, + sender: Option>, + recv_abort_handle: Option, + send_abort_handle: Option, + in_segment: gst::FormattedSegment, + out_segment: gst::FormattedSegment, + seqnum: gst::Seqnum, + buffers: VecDeque, + send_eos: bool, + discont: bool, +} + +impl Default for State { + fn default() -> Self { + Self { + connected: false, + sender: None, + recv_abort_handle: None, + send_abort_handle: None, + in_segment: gst::FormattedSegment::new(), + out_segment: gst::FormattedSegment::new(), + seqnum: gst::Seqnum::next(), + buffers: VecDeque::new(), + send_eos: false, + discont: true, + } + } +} + +struct Transcriber { + srcpad: gst::Pad, + sinkpad: gst::Pad, + settings: Mutex, + state: Mutex, + ws_sink: Mutex + Send>>>>, +} + +fn build_packet(payload: &[u8]) -> Vec { + let headers = [ + Header { + name: String::from(":event-type"), + value: String::from("AudioEvent"), + value_type: 7, + }, + Header { + name: String::from(":content-type"), + value: String::from("application/octet-stream"), + value_type: 7, + }, + Header { + name: String::from(":message-type"), + value: String::from("event"), + value_type: 7, + }, + ]; + + encode_packet(payload, &headers).expect("foobar") +} + +fn get_current_running_time(element: &gst::Element) -> gst::ClockTime { + if let Some(clock) = element.get_clock() { + if clock.get_time() > element.get_base_time() { + clock.get_time() - element.get_base_time() + } else { + 0.into() + } + } else { + gst::CLOCK_TIME_NONE + } +} + +impl Transcriber { + fn set_pad_functions(sinkpad: &gst::Pad, srcpad: &gst::Pad) { + sinkpad.set_chain_function(|pad, parent, buffer| { + Transcriber::catch_panic_pad_function( + parent, + || Err(gst::FlowError::Error), + |transcriber, element| transcriber.sink_chain(pad, element, buffer), + ) + }); + sinkpad.set_event_function(|pad, parent, event| { + Transcriber::catch_panic_pad_function( + parent, + || false, + |transcriber, element| transcriber.sink_event(pad, element, event), + ) + }); + + srcpad.set_activatemode_function(|pad, parent, mode, active| { + Transcriber::catch_panic_pad_function( + parent, + || { + Err(gst_loggable_error!( + CAT, + "Panic activating src pad with mode" + )) + }, + |transcriber, element| transcriber.src_activatemode(pad, element, mode, active), + ) + }); + srcpad.set_query_function(|pad, parent, query| { + Transcriber::catch_panic_pad_function( + parent, + || false, + |transcriber, element| transcriber.src_query(pad, element, query), + ) + }); + } + + fn push_gap(&self, element: &gst::Element) -> bool { + /* First, check our pending buffers */ + let mut items: Vec = vec![]; + + let (latency, now, mut last_position, send_eos, seqnum) = { + let mut state = self.state.lock().unwrap(); + let send_eos = state.send_eos && state.buffers.len() == 0; + + let latency: gst::ClockTime = ((self.settings.lock().unwrap().latency_ms as u64 + - GRANULARITY_MS as u64) + * 1000000) + .into(); + let now = get_current_running_time(element); + + loop { + if let Some(buf) = state.buffers.front() { + if now - buf.get_pts() > latency { + /* Safe unwrap, we know we have an item */ + let buf = state.buffers.pop_front().unwrap(); + items.push(buf) + } else { + break; + } + } + + break; + } + + ( + latency, + now, + state.out_segment.get_position(), + send_eos, + state.seqnum, + ) + }; + + /* We're EOS, we can pause and exit early */ + if send_eos { + self.srcpad.pause_task(); + return self + .srcpad + .push_event(gst::Event::new_eos().seqnum(seqnum).build()); + } + + for mut buf in items.drain(..) { + if buf.get_pts() > last_position { + let gap_event = gst::Event::new_gap(last_position, buf.get_pts() - last_position) + .seqnum(seqnum) + .build(); + gst_debug!( + CAT, + "Pushing gap: {} -> {}", + last_position, + buf.get_pts() + ); + if !self.srcpad.push_event(gap_event) { + return false; + } + } + last_position = buf.get_pts() + buf.get_duration(); + { + let mut buf = buf.get_mut().unwrap(); + buf.set_pts(buf.get_pts()); + } + gst_debug!( + CAT, + "Pushing buffer: {} -> {}", + buf.get_pts(), + buf.get_pts() + buf.get_duration() + ); + if !self.srcpad.push(buf).is_ok() { + return false; + } + } + + /* next, push a gap if we're lagging behind the target position */ + + if now - last_position > latency { + let duration = now - last_position - latency; + + let gap_event = gst::Event::new_gap(last_position, duration) + .seqnum(seqnum) + .build(); + gst_debug!( + CAT, + "Pushing gap: {} -> {}", + last_position, + last_position + duration + ); + last_position += duration; + if !self.srcpad.push_event(gap_event) { + return false; + } + } + + self.state + .lock() + .unwrap() + .out_segment + .set_position(last_position); + + true + } + + fn loop_fn( + &self, + element: &gst::Element, + receiver: &mut mpsc::Receiver, + ) -> Result<(), gst::ErrorMessage> { + let future = async move { + let msg = match receiver.next().await { + Some(msg) => msg, + /* Sender was closed */ + None => { + self.srcpad.pause_task(); + return Ok(()); + } + }; + + match msg { + Message::Binary(buf) => { + let (data, pkt) = parse_packet(&buf).unwrap(); + let payload = std::str::from_utf8(pkt.payload).unwrap(); + let transcript: Transcript = serde_json::from_str(&payload).map_err(|err| { + gst_error_msg!( + gst::StreamError::Failed, + ["Unexpected binary message: {} ({})", payload, err] + ) + })?; + + if transcript.transcript.results.len() > 0 { + let result = &transcript.transcript.results[0]; + if !result.is_partial && result.alternatives.len() > 0 { + let alternative = &result.alternatives[0]; + gst_info!(CAT, obj: element, "Transcript: {}", alternative.transcript); + + let mut start_time: gst::ClockTime = + ((result.start_time as f64 * 1000000000.0) as u64).into(); + let end_time: gst::ClockTime = + ((result.end_time as f64 * 1000000000.0) as u64).into(); + + let mut state = self.state.lock().unwrap(); + let position = state.out_segment.get_position(); + + if end_time < position { + gst_warning!(CAT, obj: element, + "Received transcript is too late by {:?}, dropping, consider increasing the latency", + position - start_time); + } else { + if start_time < position { + gst_warning!(CAT, obj: element, + "Received transcript is too late by {:?}, clipping, consider increasing the latency", + position - start_time); + start_time = position; + } + + let mut buf = gst::Buffer::from_mut_slice( + alternative.transcript.as_bytes().to_vec(), + ); + + { + let mut buf = buf.get_mut().unwrap(); + + if state.discont { + buf.set_flags(gst::BufferFlags::DISCONT); + state.discont = false; + } + + buf.set_pts(start_time); + buf.set_duration(end_time - start_time); + } + + gst_debug!(CAT, obj: element, "Adding pending buffer: {:?}", buf); + + state.buffers.push_back(buf); + } + } + } + Ok(()) + } + + _ => Ok(()), + } + }; + + /* Wrap in a timeout so we can push gaps regularly */ + let future = async move { + match tokio::time::timeout(Duration::from_millis(GRANULARITY_MS.into()), future).await { + Err(_) => { + if !self.push_gap(element) { + gst_info!(CAT, obj: element, "Failed to push gap event, pausing"); + self.srcpad.pause_task(); + } + Ok(()) + } + Ok(res) => { + if !self.push_gap(element) { + gst_info!(CAT, obj: element, "Failed to push gap event, pausing"); + self.srcpad.pause_task(); + } + res + } + } + }; + + RUNTIME.enter(|| futures::executor::block_on(future)) + } + + fn start_task(&self, element: &gst::Element) -> Result<(), gst::LoggableError> { + let element_weak = element.downgrade(); + let pad_weak = self.srcpad.downgrade(); + let (sender, mut receiver) = mpsc::channel(1); + + { + let mut state = self.state.lock().unwrap(); + state.sender = Some(sender); + } + + let res = self.srcpad.start_task(move || { + let element = match element_weak.upgrade() { + Some(element) => element, + None => { + if let Some(pad) = pad_weak.upgrade() { + pad.pause_task().unwrap(); + } + return; + } + }; + + let transcribe = Self::from_instance(&element); + transcribe.loop_fn(&element, &mut receiver).map_err(|err| { + gst_element_error!( + &element, + gst::StreamError::Failed, + ["Streaming failed: {}", err] + ); + transcribe.srcpad.pause_task().unwrap(); + }); + }); + if res.is_err() { + return Err(gst_loggable_error!(CAT, "Failed to start pad task")); + } + Ok(()) + } + + fn src_activatemode( + &self, + _pad: &gst::Pad, + element: &gst::Element, + mode: gst::PadMode, + active: bool, + ) -> Result<(), gst::LoggableError> { + if active { + self.start_task(element)?; + } else { + { + let mut state = self.state.lock().unwrap(); + state.sender = None; + } + + let _ = self.srcpad.stop_task(); + } + + Ok(()) + } + + fn src_query(&self, pad: &gst::Pad, element: &gst::Element, query: &mut gst::QueryRef) -> bool { + use gst::QueryView; + + gst_log!(CAT, obj: pad, "Handling query {:?}", query); + + match query.view_mut() { + QueryView::Latency(ref mut q) => { + let mut peer_query = gst::query::Query::new_latency(); + + let ret = self.sinkpad.peer_query(&mut peer_query); + + if ret { + let (_, min, _) = peer_query.get_result(); + let our_latency: gst::ClockTime = + (self.settings.lock().unwrap().latency_ms as u64 * 1000000).into(); + q.set(true, our_latency + min, gst::CLOCK_TIME_NONE); + } + ret + } + QueryView::Position(ref mut q) => { + if q.get_format() == gst::Format::Time { + let state = self.state.lock().unwrap(); + q.set( + state + .in_segment + .to_stream_time(state.in_segment.get_position()), + ); + true + } else { + false + } + } + _ => pad.query_default(Some(element), query), + } + } + + fn sink_event(&self, pad: &gst::Pad, element: &gst::Element, event: gst::Event) -> bool { + use gst::EventView; + + gst_debug!(CAT, obj: pad, "Handling event {:?}", event); + + match event.view() { + EventView::Eos(_) => match self.handle_buffer(pad, element, None) { + Err(err) => { + gst_error!(CAT, "Failed to send EOS to AWS: {}", err); + false + } + Ok(_) => true, + }, + EventView::FlushStart(e) => { + gst_info!(CAT, obj: element, "Received flush start, disconnecting"); + self.disconnect(element); + let ret = pad.event_default(Some(element), event); + self.srcpad.stop_task(); + ret + } + EventView::FlushStop(e) => { + gst_info!(CAT, obj: element, "Received flush stop, restarting task"); + let ret = pad.event_default(Some(element), event); + self.start_task(element); + ret + } + EventView::Segment(e) => { + let segment = match e.get_segment().clone().downcast::() { + Err(segment) => { + gst_element_error!( + element, + gst::StreamError::Format, + [ + "Only Time segments supported, got {:?}", + segment.get_format(), + ] + ); + return false; + } + Ok(segment) => segment, + }; + + let event = { + let mut state = self.state.lock().unwrap(); + + state.out_segment.set_time(segment.get_time()); + state + .out_segment + .set_position(gst::ClockTime::from_nseconds(0)); + + state.in_segment = segment; + state.seqnum = e.get_seqnum(); + gst::Event::new_segment(&state.out_segment) + .seqnum(state.seqnum) + .build() + }; + + gst_info!(CAT, "Sending our own segment: {:?}", event); + + pad.event_default(Some(element), event) + } + EventView::Tag(_) => true, + EventView::Caps(e) => { + gst_info!(CAT, "Received caps {:?}", e); + + let caps = gst::Caps::builder("text/x-raw") + .field("format", &"utf8") + .build(); + let seqnum = self.state.lock().unwrap().seqnum; + self.srcpad + .push_event(gst::Event::new_caps(&caps).seqnum(seqnum).build()) + } + _ => pad.event_default(Some(element), event), + } + } + + async fn sync_and_send( + &self, + element: &gst::Element, + buffer: Option, + ) -> Result { + let mut delay = None; + + { + let state = self.state.lock().unwrap(); + + if let Some(buffer) = &buffer { + let running_time = state.in_segment.to_running_time(buffer.get_pts()); + let now = get_current_running_time(&element); + + if now.is_some() && now < running_time { + delay = Some(running_time - now); + } + } + } + + if let Some(delay) = delay { + tokio::time::delay_for(Duration::from_nanos(delay.nseconds().unwrap())).await; + } + + if let Some(ws_sink) = self.ws_sink.lock().unwrap().as_mut() { + if let Some(buffer) = buffer { + let data = buffer.map_readable().unwrap(); + for chunk in data.chunks(8192) { + let packet = build_packet(chunk); + ws_sink.send(Message::Binary(packet)).await.map_err(|err| { + gst_error!(CAT, obj: element, "Failed sending packet: {}", err); + gst::FlowError::Error + })?; + } + } else { + // EOS + let packet = build_packet(&[]); + ws_sink.send(Message::Binary(packet)).await.map_err(|err| { + gst_error!(CAT, obj: element, "Failed sending packet: {}", err); + gst::FlowError::Error + })?; + } + } + + Ok(gst::FlowSuccess::Ok) + } + + fn handle_buffer( + &self, + pad: &gst::Pad, + element: &gst::Element, + buffer: Option, + ) -> Result { + gst_debug!(CAT, obj: element, "Handling {:?}", buffer); + + self.ensure_connection(element).map_err(|err| { + gst_element_error!( + &element, + gst::StreamError::Failed, + ["Streaming failed: {}", err] + ); + gst::FlowError::Error + })?; + + let (future, abort_handle) = abortable(self.sync_and_send(element, buffer)); + + self.state.lock().unwrap().send_abort_handle = Some(abort_handle); + + let ret = match RUNTIME.enter(|| futures::executor::block_on(future)) { + Err(err) => Err(gst::FlowError::Flushing), + Ok(res) => res, + }; + + ret + } + + fn sink_chain( + &self, + pad: &gst::Pad, + element: &gst::Element, + buffer: gst::Buffer, + ) -> Result { + self.handle_buffer(pad, element, Some(buffer)) + } + + fn ensure_connection(&self, element: &gst::Element) -> Result<(), gst::ErrorMessage> { + let mut state = self.state.lock().unwrap(); + + if state.connected == true { + return Ok(()); + } + + let in_caps = self.sinkpad.get_current_caps().unwrap(); + let s = in_caps.get_structure(0).unwrap(); + let sample_rate: i32 = s.get("rate").unwrap().unwrap(); + + let settings = self.settings.lock().unwrap(); + + gst_info!(CAT, obj: element, "Connecting .."); + + let creds = RUNTIME + .enter(|| futures::executor::block_on(EnvironmentProvider::default().credentials())) + .map_err(|err| { + gst_error!(CAT, obj: element, "Failed to generate credentials: {}", err); + gst_error_msg!( + gst::CoreError::Failed, + ["Failed to generate credentials: {}", err] + ) + })?; + + let language_code = settings + .language_code + .as_ref() + .expect("Language code is required"); + + let request = StartTranscriptionJobRequest { + job_execution_settings: None, + language_code: language_code.to_string(), + media: Media { + media_file_uri: None, + }, + media_format: Some("wav".to_string()), + media_sample_rate_hertz: Some(sample_rate.into()), + output_bucket_name: None, + output_encryption_kms_key_id: None, + settings: Some(TranscriptionSettings { + channel_identification: None, + max_alternatives: None, + max_speaker_labels: None, + show_alternatives: Some(false), + show_speaker_labels: None, + vocabulary_filter_method: None, + vocabulary_filter_name: None, + vocabulary_name: None, + }), + transcription_job_name: element.get_name().to_string(), + }; + + let region = Region::UsEast1; + + let mut signed = SignedRequest::new( + "GET", + "transcribe", + ®ion, + "/stream-transcription-websocket", + ); + signed.set_hostname(Some(format!( + "transcribestreaming.{}.amazonaws.com:8443", + region.name() + ))); + signed.add_param("language-code", language_code); + signed.add_param("media-encoding", "pcm"); + signed.add_param("sample-rate", &sample_rate.to_string()); + let url = signed.generate_presigned_url(&creds, &std::time::Duration::from_secs(60), true); + + let (mut ws, _) = RUNTIME + .enter(|| futures::executor::block_on(connect_async(format!("wss{}", &url[5..])))) + .map_err(|err| { + gst_error!(CAT, obj: element, "Failed to connect: {}", err); + gst_error_msg!(gst::CoreError::Failed, ["Failed to connect: {}", err]) + })?; + + let (ws_sink, mut ws_stream) = ws.split(); + + *self.ws_sink.lock().unwrap() = Some(Box::pin(ws_sink)); + + let element_weak = element.downgrade(); + let future = async move { + loop { + let element = match element_weak.upgrade() { + Some(element) => element, + None => break, + }; + let transcribe = Self::from_instance(&element); + let msg = match ws_stream.next().await { + Some(msg) => msg, + None => { + let mut state = transcribe.state.lock().unwrap(); + state.send_eos = true; + break; + } + }; + + let msg = match msg { + Ok(msg) => msg, + Err(err) => { + gst_error!(CAT, "Failed to receive data: {}", err); + gst_element_error!( + element, + gst::StreamError::Failed, + ["Streaming failed: {}", err] + ); + break; + } + }; + + let mut sender = transcribe.state.lock().unwrap().sender.clone(); + + if let Some(sender) = sender.as_mut() { + sender.send(msg).await; + } + } + }; + + let (future, abort_handle) = abortable(future); + + state.recv_abort_handle = Some(abort_handle); + + RUNTIME.spawn(future); + + state.connected = true; + + gst_info!(CAT, obj: element, "Connected"); + + Ok(()) + } + + fn disconnect(&self, element: &gst::Element) -> Result<(), gst::ErrorMessage> { + let mut state = self.state.lock().unwrap(); + + gst_info!(CAT, obj: element, "Unpreparing"); + + if let Some(abort_handle) = state.recv_abort_handle.take() { + abort_handle.abort(); + } + + if let Some(abort_handle) = state.send_abort_handle.take() { + abort_handle.abort(); + } + + *state = State::default(); + + gst_info!( + CAT, + obj: element, + "Unprepared, connected: {}!", + state.connected + ); + + Ok(()) + } +} + +impl ObjectSubclass for Transcriber { + const NAME: &'static str = "RsAwsTranscriber"; + type ParentType = gst::Element; + type Instance = gst::subclass::ElementInstanceStruct; + type Class = subclass::simple::ClassStruct; + + glib_object_subclass!(); + + fn new_with_class(klass: &subclass::simple::ClassStruct) -> Self { + let templ = klass.get_pad_template("sink").unwrap(); + let sinkpad = gst::Pad::new_from_template(&templ, Some("sink")); + let templ = klass.get_pad_template("src").unwrap(); + let srcpad = gst::Pad::new_from_template(&templ, Some("src")); + + srcpad.use_fixed_caps(); + + Transcriber::set_pad_functions(&sinkpad, &srcpad); + let settings = Mutex::new(Settings::default()); + + Self { + srcpad, + sinkpad, + settings, + state: Mutex::new(State::default()), + ws_sink: Mutex::new(None), + } + } + + fn class_init(klass: &mut subclass::simple::ClassStruct) { + klass.set_metadata( + "Transcriber", + "Audio/Text/Filter", + "Speech to Text filter, using AWS transcribe", + "Jordan Petridis , Mathieu Duponchelle ", + ); + + let src_caps = gst::Caps::builder("text/x-raw") + .field("format", &"utf8") + .build(); + let src_pad_template = gst::PadTemplate::new( + "src", + gst::PadDirection::Src, + gst::PadPresence::Always, + &src_caps, + ) + .unwrap(); + klass.add_pad_template(src_pad_template); + + let sink_caps = gst::Caps::builder("audio/x-raw") + .field("format", &"S16LE") + .field("rate", &gst::IntRange::::new(8000, 48000)) + .field("channels", &1) + .build(); + let sink_pad_template = gst::PadTemplate::new( + "sink", + gst::PadDirection::Sink, + gst::PadPresence::Always, + &sink_caps, + ) + .unwrap(); + klass.add_pad_template(sink_pad_template); + klass.install_properties(&PROPERTIES); + } +} + +impl ObjectImpl for Transcriber { + glib_object_impl!(); + + fn constructed(&self, obj: &glib::Object) { + self.parent_constructed(obj); + + let element = obj.downcast_ref::().unwrap(); + element.add_pad(&self.sinkpad).unwrap(); + element.add_pad(&self.srcpad).unwrap(); + } + + fn set_property(&self, _obj: &glib::Object, id: usize, value: &glib::Value) { + let prop = &PROPERTIES[id]; + + match *prop { + subclass::Property("language_code", ..) => { + let mut settings = self.settings.lock().unwrap(); + settings.language_code = value.get().expect("type checked upstream"); + } + subclass::Property("latency", ..) => { + let mut settings = self.settings.lock().unwrap(); + settings.latency_ms = value.get_some().expect("type checked upstream"); + } + _ => unimplemented!(), + } + } + + fn get_property(&self, _obj: &glib::Object, id: usize) -> Result { + let prop = &PROPERTIES[id]; + + match *prop { + subclass::Property("language-code", ..) => { + let settings = self.settings.lock().unwrap(); + Ok(settings.language_code.to_value()) + } + subclass::Property("latency", ..) => { + let settings = self.settings.lock().unwrap(); + Ok(settings.latency_ms.to_value()) + } + _ => unimplemented!(), + } + } +} + +impl ElementImpl for Transcriber { + fn change_state( + &self, + element: &gst::Element, + transition: gst::StateChange, + ) -> Result { + gst_info!(CAT, obj: element, "Changing state {:?}", transition); + + match transition { + gst::StateChange::PausedToReady => { + self.disconnect(element).map_err(|err| { + element.post_error_message(&err); + gst::StateChangeError + })?; + } + _ => (), + } + + let mut success = self.parent_change_state(element, transition)?; + + match transition { + gst::StateChange::ReadyToPaused => { + success = gst::StateChangeSuccess::NoPreroll; + } + gst::StateChange::PlayingToPaused => { + success = gst::StateChangeSuccess::NoPreroll; + } + _ => (), + } + + Ok(success) + } +} + +pub fn register(plugin: &gst::Plugin) -> Result<(), glib::BoolError> { + gst::Element::register( + Some(plugin), + "awstranscriber", + gst::Rank::None, + Transcriber::get_type(), + ) +}