diff --git a/audio/transcribe/Cargo.toml b/audio/transcribe/Cargo.toml index 9e5667c5..a36ea934 100644 --- a/audio/transcribe/Cargo.toml +++ b/audio/transcribe/Cargo.toml @@ -1,18 +1,16 @@ [package] name = "gst-plugin-transcribe" version = "0.1.0" -authors = ["Jordan Petridis "] +authors = ["Jordan Petridis ", "Mathieu Duponchelle "] edition = "2018" description = "AWS Transcribe plugin" repository = "https://gitlab.freedesktop.org/gstreamer/gst-plugins-rs" -# FIXME: licence +license = "LGPL-2.1+" [dependencies] glib = { git = "https://github.com/gtk-rs/glib" } -# FIXME: -# gst = { git = "https://gitlab.freedesktop.org/gstreamer/gstreamer-rs", features = ["v1_16"], package = "gstreamer" } -gstreamer = { git = "https://gitlab.freedesktop.org/gstreamer/gstreamer-rs", features = ["v1_16"] } -gst_base = { git = "https://gitlab.freedesktop.org/gstreamer/gstreamer-rs", features = ["v1_16"], package ="gstreamer-base" } +gstreamer = { git = "https://gitlab.freedesktop.org/gstreamer/gstreamer-rs" } +gstreamer-base = { git = "https://gitlab.freedesktop.org/gstreamer/gstreamer-rs" } rusoto_core = "0.43.0-beta.1" rusoto_credential = "0.43.0-beta.1" rusoto_transcribe = "0.43.0-beta.1" @@ -24,7 +22,7 @@ async-tungstenite = { version = "0.4", features = ["tokio", "tokio-runtime", "to nom = "5.1.1" crc = "1.8.1" byteorder = "1.3.4" -lazy_static = "1.4.0" +once_cell = "1.0" serde = "1" serde_derive = "1" serde_json = "1" diff --git a/audio/transcribe/src/aws_transcribe_parse.rs b/audio/transcribe/src/aws_transcribe_parse.rs index 7a7f5815..c7e7d071 100644 --- a/audio/transcribe/src/aws_transcribe_parse.rs +++ b/audio/transcribe/src/aws_transcribe_parse.rs @@ -1,4 +1,19 @@ -#![allow(unused)] +// Copyright (C) 2020 Mathieu Duponchelle +// +// This library is free software; you can redistribute it and/or +// modify it under the terms of the GNU Library General Public +// License as published by the Free Software Foundation; either +// version 2 of the License, or (at your option) any later version. +// +// This library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// Library General Public License for more details. +// +// You should have received a copy of the GNU Library General Public +// License along with this library; if not, write to the +// Free Software Foundation, Inc., 51 Franklin Street, Suite 500, +// Boston, MA 02110-1335, USA. use glib; use glib::prelude::*; @@ -8,45 +23,32 @@ use gst; use gst::prelude::*; use gst::subclass::prelude::*; -use std::convert::TryInto; use std::default::Default; -use std::env; use rusoto_core::Region; use rusoto_credential; use rusoto_credential::{EnvironmentProvider, ProvideAwsCredentials}; -use rusoto_transcribe; use rusoto_signature::signature::SignedRequest; -use rusoto_signature::signature::SignedRequestPayload; -use rusoto_signature::stream::ByteStream; -use rusoto_transcribe::Media; -use rusoto_transcribe::Settings as TranscriptionSettings; -use rusoto_transcribe::StartTranscriptionJobRequest; -use rusoto_transcribe::Transcribe; use async_tungstenite::tungstenite::error::Error as WsError; use async_tungstenite::{tokio::connect_async, tungstenite::Message}; use futures::channel::mpsc; use futures::future::{abortable, AbortHandle}; -use futures::io::{AsyncReadExt, Cursor}; use futures::prelude::*; -use futures::stream::SplitSink; use tokio::runtime; -use crc::crc32; - -use std::boxed::Box; +use std::borrow::Cow; use std::collections::VecDeque; use std::pin::Pin; -use std::sync::{Mutex, MutexGuard}; +use std::sync::Mutex; use std::time::Duration; -use lazy_static; - use crate::packet::*; -use serde_derive::{Deserialize, Serialize}; +use serde_derive::Deserialize; + +use once_cell::sync::Lazy; #[derive(Deserialize, Debug)] #[serde(rename_all = "PascalCase")] @@ -87,21 +89,22 @@ struct Transcript { transcript: TranscriptTranscript, } -lazy_static! { - static ref CAT: gst::DebugCategory = { - gst::DebugCategory::new( - "awstranscribe", - gst::DebugColorFlags::empty(), - Some("AWS Transcribe element"), - ) - }; - static ref RUNTIME: runtime::Runtime = runtime::Builder::new() +static CAT: Lazy = Lazy::new(|| { + gst::DebugCategory::new( + "awstranscribe", + gst::DebugColorFlags::empty(), + Some("AWS Transcribe element"), + ) +}); + +static RUNTIME: Lazy = Lazy::new(|| { + runtime::Builder::new() .threaded_scheduler() .enable_all() .core_threads(1) .build() - .unwrap(); -} + .unwrap() +}); const DEFAULT_LATENCY_MS: u32 = 30000; const GRANULARITY_MS: u32 = 100; @@ -110,7 +113,7 @@ static PROPERTIES: [subclass::Property; 2] = [ subclass::Property("language-code", |name| { glib::ParamSpec::string( name, - "Language-Code", + "Language Code", "The Language of the Stream, see \ \ for an up to date list of allowed languages", @@ -121,7 +124,7 @@ static PROPERTIES: [subclass::Property; 2] = [ subclass::Property("latency", |name| { glib::ParamSpec::uint( name, - "latency in ms", + "Latency", "Amount of milliseconds to allow AWS transcribe", GRANULARITY_MS, std::u32::MAX, @@ -176,29 +179,31 @@ impl Default for State { } } +type WsSink = Pin + Send>>; + struct Transcriber { srcpad: gst::Pad, sinkpad: gst::Pad, settings: Mutex, state: Mutex, - ws_sink: Mutex + Send>>>>, + ws_sink: Mutex>, } fn build_packet(payload: &[u8]) -> Vec { let headers = [ Header { - name: String::from(":event-type"), - value: String::from("AudioEvent"), + name: Cow::Borrowed(":event-type"), + value: Cow::Borrowed("AudioEvent"), value_type: 7, }, Header { - name: String::from(":content-type"), - value: String::from("application/octet-stream"), + name: Cow::Borrowed(":content-type"), + value: Cow::Borrowed("application/octet-stream"), value_type: 7, }, Header { - name: String::from(":message-type"), - value: String::from("event"), + name: Cow::Borrowed(":message-type"), + value: Cow::Borrowed("event"), value_type: 7, }, ]; @@ -256,32 +261,27 @@ impl Transcriber { }); } - fn push_gap(&self, element: &gst::Element) -> bool { + fn dequeue(&self, element: &gst::Element) -> bool { /* First, check our pending buffers */ - let mut items: Vec = vec![]; + let mut items = vec![]; let (latency, now, mut last_position, send_eos, seqnum) = { let mut state = self.state.lock().unwrap(); - let send_eos = state.send_eos && state.buffers.len() == 0; + let send_eos = state.send_eos && state.buffers.is_empty(); - let latency: gst::ClockTime = ((self.settings.lock().unwrap().latency_ms as u64 + let latency: gst::ClockTime = (self.settings.lock().unwrap().latency_ms as u64 - GRANULARITY_MS as u64) - * 1000000) - .into(); + * gst::MSECOND; let now = get_current_running_time(element); - loop { - if let Some(buf) = state.buffers.front() { - if now - buf.get_pts() > latency { - /* Safe unwrap, we know we have an item */ - let buf = state.buffers.pop_front().unwrap(); - items.push(buf) - } else { - break; - } + while let Some(buf) = state.buffers.front() { + if now - buf.get_pts() > latency { + /* Safe unwrap, we know we have an item */ + let buf = state.buffers.pop_front().unwrap(); + items.push(buf); + } else { + break; } - - break; } ( @@ -295,7 +295,8 @@ impl Transcriber { /* We're EOS, we can pause and exit early */ if send_eos { - self.srcpad.pause_task(); + let _ = self.srcpad.pause_task(); + return self .srcpad .push_event(gst::Event::new_eos().seqnum(seqnum).build()); @@ -318,7 +319,7 @@ impl Transcriber { } last_position = buf.get_pts() + buf.get_duration(); { - let mut buf = buf.get_mut().unwrap(); + let buf = buf.get_mut().unwrap(); buf.set_pts(buf.get_pts()); } gst_debug!( @@ -327,7 +328,7 @@ impl Transcriber { buf.get_pts(), buf.get_pts() + buf.get_duration() ); - if !self.srcpad.push(buf).is_ok() { + if self.srcpad.push(buf).is_err() { return false; } } @@ -371,32 +372,33 @@ impl Transcriber { Some(msg) => msg, /* Sender was closed */ None => { - self.srcpad.pause_task(); + let _ = self.srcpad.pause_task(); return Ok(()); } }; match msg { Message::Binary(buf) => { - let (data, pkt) = parse_packet(&buf).unwrap(); + let (_, pkt) = parse_packet(&buf).unwrap(); let payload = std::str::from_utf8(pkt.payload).unwrap(); - let transcript: Transcript = serde_json::from_str(&payload).map_err(|err| { - gst_error_msg!( - gst::StreamError::Failed, - ["Unexpected binary message: {} ({})", payload, err] - ) - })?; + let mut transcript: Transcript = + serde_json::from_str(&payload).map_err(|err| { + gst_error_msg!( + gst::StreamError::Failed, + ["Unexpected binary message: {} ({})", payload, err] + ) + })?; - if transcript.transcript.results.len() > 0 { - let result = &transcript.transcript.results[0]; - if !result.is_partial && result.alternatives.len() > 0 { - let alternative = &result.alternatives[0]; + if !transcript.transcript.results.is_empty() { + let mut result = transcript.transcript.results.remove(0); + if !result.is_partial && !result.alternatives.is_empty() { + let alternative = result.alternatives.remove(0); gst_info!(CAT, obj: element, "Transcript: {}", alternative.transcript); let mut start_time: gst::ClockTime = - ((result.start_time as f64 * 1000000000.0) as u64).into(); + ((result.start_time as f64 * 1_000_000_000.0) as u64).into(); let end_time: gst::ClockTime = - ((result.end_time as f64 * 1000000000.0) as u64).into(); + ((result.end_time as f64 * 1_000_000_000.0) as u64).into(); let mut state = self.state.lock().unwrap(); let position = state.out_segment.get_position(); @@ -414,11 +416,11 @@ impl Transcriber { } let mut buf = gst::Buffer::from_mut_slice( - alternative.transcript.as_bytes().to_vec(), + alternative.transcript.into_bytes(), ); { - let mut buf = buf.get_mut().unwrap(); + let buf = buf.get_mut().unwrap(); if state.discont { buf.set_flags(gst::BufferFlags::DISCONT); @@ -446,16 +448,18 @@ impl Transcriber { let future = async move { match tokio::time::timeout(Duration::from_millis(GRANULARITY_MS.into()), future).await { Err(_) => { - if !self.push_gap(element) { + if !self.dequeue(element) { gst_info!(CAT, obj: element, "Failed to push gap event, pausing"); - self.srcpad.pause_task(); + + let _ = self.srcpad.pause_task(); } Ok(()) } Ok(res) => { - if !self.push_gap(element) { + if !self.dequeue(element) { gst_info!(CAT, obj: element, "Failed to push gap event, pausing"); - self.srcpad.pause_task(); + + let _ = self.srcpad.pause_task(); } res } @@ -480,21 +484,24 @@ impl Transcriber { Some(element) => element, None => { if let Some(pad) = pad_weak.upgrade() { - pad.pause_task().unwrap(); + let _ = pad.pause_task(); } return; } }; let transcribe = Self::from_instance(&element); - transcribe.loop_fn(&element, &mut receiver).map_err(|err| { - gst_element_error!( - &element, - gst::StreamError::Failed, - ["Streaming failed: {}", err] - ); - transcribe.srcpad.pause_task().unwrap(); - }); + match transcribe.loop_fn(&element, &mut receiver) { + Err(err) => { + gst_element_error!( + &element, + gst::StreamError::Failed, + ["Streaming failed: {}", err] + ); + let _ = transcribe.srcpad.pause_task(); + } + Ok(_) => (), + }; }); if res.is_err() { return Err(gst_loggable_error!(CAT, "Failed to start pad task")); @@ -506,7 +513,7 @@ impl Transcriber { &self, _pad: &gst::Pad, element: &gst::Element, - mode: gst::PadMode, + _mode: gst::PadMode, active: bool, ) -> Result<(), gst::LoggableError> { if active { @@ -537,7 +544,7 @@ impl Transcriber { if ret { let (_, min, _) = peer_query.get_result(); let our_latency: gst::ClockTime = - (self.settings.lock().unwrap().latency_ms as u64 * 1000000).into(); + self.settings.lock().unwrap().latency_ms as u64 * gst::MSECOND; q.set(true, our_latency + min, gst::CLOCK_TIME_NONE); } ret @@ -547,8 +554,8 @@ impl Transcriber { let state = self.state.lock().unwrap(); q.set( state - .in_segment - .to_stream_time(state.in_segment.get_position()), + .out_segment + .to_stream_time(state.out_segment.get_position()), ); true } else { @@ -572,18 +579,47 @@ impl Transcriber { } Ok(_) => true, }, - EventView::FlushStart(e) => { + EventView::FlushStart(_) => { gst_info!(CAT, obj: element, "Received flush start, disconnecting"); - self.disconnect(element); - let ret = pad.event_default(Some(element), event); - self.srcpad.stop_task(); - ret + match self.disconnect(element) { + Err(err) => { + element.post_error_message(&err); + false + } + Ok(_) => { + let mut ret = pad.event_default(Some(element), event); + + match self.srcpad.stop_task() { + Err(err) => { + gst_error!( + CAT, + obj: element, + "Failed to stop srcpad task: {}", + err + ); + ret = false; + } + Ok(_) => (), + }; + + ret + } + } } - EventView::FlushStop(e) => { + EventView::FlushStop(_) => { gst_info!(CAT, obj: element, "Received flush stop, restarting task"); - let ret = pad.event_default(Some(element), event); - self.start_task(element); - ret + + if pad.event_default(Some(element), event) { + match self.start_task(element) { + Err(err) => { + gst_error!(CAT, obj: element, "Failed to start srcpad task: {}", err); + false + } + Ok(_) => true, + } + } else { + false + } } EventView::Segment(e) => { let segment = match e.get_segment().clone().downcast::() { @@ -684,7 +720,7 @@ impl Transcriber { fn handle_buffer( &self, - pad: &gst::Pad, + _pad: &gst::Pad, element: &gst::Element, buffer: Option, ) -> Result { @@ -703,12 +739,10 @@ impl Transcriber { self.state.lock().unwrap().send_abort_handle = Some(abort_handle); - let ret = match RUNTIME.enter(|| futures::executor::block_on(future)) { - Err(err) => Err(gst::FlowError::Flushing), + match RUNTIME.enter(|| futures::executor::block_on(future)) { + Err(_) => Err(gst::FlowError::Flushing), Ok(res) => res, - }; - - ret + } } fn sink_chain( @@ -723,7 +757,7 @@ impl Transcriber { fn ensure_connection(&self, element: &gst::Element) -> Result<(), gst::ErrorMessage> { let mut state = self.state.lock().unwrap(); - if state.connected == true { + if state.connected { return Ok(()); } @@ -750,29 +784,6 @@ impl Transcriber { .as_ref() .expect("Language code is required"); - let request = StartTranscriptionJobRequest { - job_execution_settings: None, - language_code: language_code.to_string(), - media: Media { - media_file_uri: None, - }, - media_format: Some("wav".to_string()), - media_sample_rate_hertz: Some(sample_rate.into()), - output_bucket_name: None, - output_encryption_kms_key_id: None, - settings: Some(TranscriptionSettings { - channel_identification: None, - max_alternatives: None, - max_speaker_labels: None, - show_alternatives: Some(false), - show_speaker_labels: None, - vocabulary_filter_method: None, - vocabulary_filter_name: None, - vocabulary_name: None, - }), - transcription_job_name: element.get_name().to_string(), - }; - let region = Region::UsEast1; let mut signed = SignedRequest::new( @@ -790,7 +801,7 @@ impl Transcriber { signed.add_param("sample-rate", &sample_rate.to_string()); let url = signed.generate_presigned_url(&creds, &std::time::Duration::from_secs(60), true); - let (mut ws, _) = RUNTIME + let (ws, _) = RUNTIME .enter(|| futures::executor::block_on(connect_async(format!("wss{}", &url[5..])))) .map_err(|err| { gst_error!(CAT, obj: element, "Failed to connect: {}", err); @@ -803,11 +814,7 @@ impl Transcriber { let element_weak = element.downgrade(); let future = async move { - loop { - let element = match element_weak.upgrade() { - Some(element) => element, - None => break, - }; + while let Some(element) = element_weak.upgrade() { let transcribe = Self::from_instance(&element); let msg = match ws_stream.next().await { Some(msg) => msg, @@ -834,7 +841,9 @@ impl Transcriber { let mut sender = transcribe.state.lock().unwrap().sender.clone(); if let Some(sender) = sender.as_mut() { - sender.send(msg).await; + if sender.send(msg).await.is_err() { + break; + } } } }; diff --git a/audio/transcribe/src/lib.rs b/audio/transcribe/src/lib.rs index 00ffacc3..817e354b 100644 --- a/audio/transcribe/src/lib.rs +++ b/audio/transcribe/src/lib.rs @@ -1,10 +1,22 @@ -// FIXME: add lgpl 2.1 license +// Copyright (C) 2020 Mathieu Duponchelle +// +// This library is free software; you can redistribute it and/or +// modify it under the terms of the GNU Library General Public +// License as published by the Free Software Foundation; either +// version 2 of the License, or (at your option) any later version. +// +// This library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// Library General Public License for more details. +// +// You should have received a copy of the GNU Library General Public +// License along with this library; if not, write to the +// Free Software Foundation, Inc., 51 Franklin Street, Suite 500, +// Boston, MA 02110-1335, USA. #![crate_type = "cdylib"] -#[macro_use] -extern crate lazy_static; - #[macro_use] extern crate glib; #[macro_use] diff --git a/audio/transcribe/src/packet.rs b/audio/transcribe/src/packet.rs index 22406e80..2cca7be3 100644 --- a/audio/transcribe/src/packet.rs +++ b/audio/transcribe/src/packet.rs @@ -1,30 +1,47 @@ +// Copyright (C) 2020 Jordan Petridis +// +// This library is free software; you can redistribute it and/or +// modify it under the terms of the GNU Library General Public +// License as published by the Free Software Foundation; either +// version 2 of the License, or (at your option) any later version. +// +// This library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// Library General Public License for more details. +// +// You should have received a copy of the GNU Library General Public +// License along with this library; if not, write to the +// Free Software Foundation, Inc., 51 Franklin Street, Suite 500, +// Boston, MA 02110-1335, USA. + use byteorder::{BigEndian, WriteBytesExt}; use crc::crc32; use nom::{ self, bytes::complete::take, combinator::map_res, multi::many0, number::complete::be_u16, number::complete::be_u32, number::complete::be_u8, sequence::tuple, IResult, }; +use std::borrow::Cow; use std::io::{self, Write}; #[derive(Debug)] -pub struct Prelude { +struct Prelude { total_bytes: u32, header_bytes: u32, prelude_crc: u32, } #[derive(Debug)] -// FIXME: make private pub struct Header { - pub name: String, + pub name: Cow<'static, str>, pub value_type: u8, - pub value: String, + pub value: Cow<'static, str>, } #[derive(Debug)] pub struct Packet<'a> { prelude: Prelude, - pub headers: Vec
, + headers: Vec
, pub payload: &'a [u8], msg_crc: u32, } @@ -109,9 +126,9 @@ fn parse_header(input: &[u8]) -> IResult<&[u8], Header> { let (input, value) = map_res(take(value_length), std::str::from_utf8)(input)?; let header = Header { - name: name.to_string(), + name: Cow::Owned(name.to_string()), value_type, - value: value.to_string(), + value: Cow::Owned(value.to_string()), }; Ok((input, header)) @@ -125,7 +142,6 @@ pub fn parse_packet(input: &[u8]) -> IResult<&[u8], Packet> { let (_, msg_crc) = be_u32(&input[input.len() - 4..])?; if msg_crc != sum { - // FIXME: a better errortype than mapres return Err(nom::Err::Error(( b"Prelude CRC doesn't match", nom::error::ErrorKind::MapRes, @@ -134,7 +150,6 @@ pub fn parse_packet(input: &[u8]) -> IResult<&[u8], Packet> { let (remainder, header_input) = take(prelude.header_bytes)(remainder)?; let (_, headers) = many0(parse_header)(header_input)?; - //dbg!(&headers); let payload_length = prelude.total_bytes - prelude.header_bytes - 4 - 12; let (remainder, payload) = take(payload_length)(remainder)?;