gst-plugin-transcribe: address latest review comments

This commit is contained in:
Mathieu Duponchelle 2020-04-10 23:51:19 +02:00 committed by Sebastian Dröge
parent a31b3c5c83
commit 7c79f73a4c
2 changed files with 62 additions and 24 deletions

View file

@ -38,7 +38,6 @@ use futures::future::{abortable, AbortHandle};
use futures::prelude::*;
use tokio::runtime;
use std::borrow::Cow;
use std::collections::VecDeque;
use std::pin::Pin;
use std::sync::Mutex;
@ -89,6 +88,12 @@ struct Transcript {
transcript: TranscriptTranscript,
}
#[derive(Deserialize, Debug)]
#[serde(rename_all = "PascalCase")]
struct ExceptionMessage {
message: String,
}
static CAT: Lazy<gst::DebugCategory> = Lazy::new(|| {
gst::DebugCategory::new(
"awstranscribe",
@ -192,18 +197,18 @@ struct Transcriber {
fn build_packet(payload: &[u8]) -> Vec<u8> {
let headers = [
Header {
name: Cow::Borrowed(":event-type"),
value: Cow::Borrowed("AudioEvent"),
name: ":event-type".into(),
value: "AudioEvent".into(),
value_type: 7,
},
Header {
name: Cow::Borrowed(":content-type"),
value: Cow::Borrowed("application/octet-stream"),
name: ":content-type".into(),
value: "application/octet-stream".into(),
value_type: 7,
},
Header {
name: Cow::Borrowed(":message-type"),
value: Cow::Borrowed("event"),
name: ":message-type".into(),
value: "event".into(),
value_type: 7,
},
];
@ -211,18 +216,6 @@ fn build_packet(payload: &[u8]) -> Vec<u8> {
encode_packet(payload, &headers).expect("foobar")
}
fn get_current_running_time(element: &gst::Element) -> gst::ClockTime {
if let Some(clock) = element.get_clock() {
if clock.get_time() > element.get_base_time() {
clock.get_time() - element.get_base_time()
} else {
0.into()
}
} else {
gst::CLOCK_TIME_NONE
}
}
impl Transcriber {
fn set_pad_functions(sinkpad: &gst::Pad, srcpad: &gst::Pad) {
sinkpad.set_chain_function(|pad, parent, buffer| {
@ -272,7 +265,7 @@ impl Transcriber {
let latency: gst::ClockTime = (self.settings.lock().unwrap().latency_ms as u64
- GRANULARITY_MS as u64)
* gst::MSECOND;
let now = get_current_running_time(element);
let now = element.get_current_running_time();
while let Some(buf) = state.buffers.front() {
if now - buf.get_pts() > latency {
@ -379,8 +372,43 @@ impl Transcriber {
match msg {
Message::Binary(buf) => {
let (_, pkt) = parse_packet(&buf).unwrap();
let (_, pkt) = parse_packet(&buf).map_err(|err| {
gst_error!(CAT, obj: element, "Failed to parse packet: {}", err);
gst_error_msg!(
gst::StreamError::Failed,
["Failed to parse packet: {}", err]
)
})?;
let payload = std::str::from_utf8(pkt.payload).unwrap();
if packet_is_exception(&pkt) {
let message: ExceptionMessage =
serde_json::from_str(&payload).map_err(|err| {
gst_error!(
CAT,
obj: element,
"Unexpected exception message: {} ({})",
payload,
err
);
gst_error_msg!(
gst::StreamError::Failed,
["Unexpected exception message: {} ({})", payload, err]
)
})?;
gst_error!(
CAT,
obj: element,
"AWS raised an error: {}",
message.message
);
Err(gst_error_msg!(
gst::StreamError::Failed,
["AWS raised an error: {}", message.message]
))?;
}
let mut transcript: Transcript =
serde_json::from_str(&payload).map_err(|err| {
gst_error_msg!(
@ -683,7 +711,7 @@ impl Transcriber {
if let Some(buffer) = &buffer {
let running_time = state.in_segment.to_running_time(buffer.get_pts());
let now = get_current_running_time(&element);
let now = element.get_current_running_time();
if now.is_some() && now < running_time {
delay = Some(running_time - now);

View file

@ -126,14 +126,24 @@ fn parse_header(input: &[u8]) -> IResult<&[u8], Header> {
let (input, value) = map_res(take(value_length), std::str::from_utf8)(input)?;
let header = Header {
name: Cow::Owned(name.to_string()),
name: name.to_string().into(),
value_type,
value: Cow::Owned(value.to_string()),
value: value.to_string().into(),
};
Ok((input, header))
}
pub fn packet_is_exception(packet: &Packet) -> bool {
for header in &packet.headers {
if header.name == ":message-type" && header.value_type == 7 {
return true;
}
}
false
}
pub fn parse_packet(input: &[u8]) -> IResult<&[u8], Packet> {
let (remainder, prelude) = parse_prelude(input)?;