mirror of
https://gitlab.freedesktop.org/gstreamer/gst-plugins-rs.git
synced 2024-11-22 03:21:00 +00:00
gst-plugin-transcribe: address latest review comments
This commit is contained in:
parent
a31b3c5c83
commit
7c79f73a4c
2 changed files with 62 additions and 24 deletions
|
@ -38,7 +38,6 @@ use futures::future::{abortable, AbortHandle};
|
|||
use futures::prelude::*;
|
||||
use tokio::runtime;
|
||||
|
||||
use std::borrow::Cow;
|
||||
use std::collections::VecDeque;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Mutex;
|
||||
|
@ -89,6 +88,12 @@ struct Transcript {
|
|||
transcript: TranscriptTranscript,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
#[serde(rename_all = "PascalCase")]
|
||||
struct ExceptionMessage {
|
||||
message: String,
|
||||
}
|
||||
|
||||
static CAT: Lazy<gst::DebugCategory> = Lazy::new(|| {
|
||||
gst::DebugCategory::new(
|
||||
"awstranscribe",
|
||||
|
@ -192,18 +197,18 @@ struct Transcriber {
|
|||
fn build_packet(payload: &[u8]) -> Vec<u8> {
|
||||
let headers = [
|
||||
Header {
|
||||
name: Cow::Borrowed(":event-type"),
|
||||
value: Cow::Borrowed("AudioEvent"),
|
||||
name: ":event-type".into(),
|
||||
value: "AudioEvent".into(),
|
||||
value_type: 7,
|
||||
},
|
||||
Header {
|
||||
name: Cow::Borrowed(":content-type"),
|
||||
value: Cow::Borrowed("application/octet-stream"),
|
||||
name: ":content-type".into(),
|
||||
value: "application/octet-stream".into(),
|
||||
value_type: 7,
|
||||
},
|
||||
Header {
|
||||
name: Cow::Borrowed(":message-type"),
|
||||
value: Cow::Borrowed("event"),
|
||||
name: ":message-type".into(),
|
||||
value: "event".into(),
|
||||
value_type: 7,
|
||||
},
|
||||
];
|
||||
|
@ -211,18 +216,6 @@ fn build_packet(payload: &[u8]) -> Vec<u8> {
|
|||
encode_packet(payload, &headers).expect("foobar")
|
||||
}
|
||||
|
||||
fn get_current_running_time(element: &gst::Element) -> gst::ClockTime {
|
||||
if let Some(clock) = element.get_clock() {
|
||||
if clock.get_time() > element.get_base_time() {
|
||||
clock.get_time() - element.get_base_time()
|
||||
} else {
|
||||
0.into()
|
||||
}
|
||||
} else {
|
||||
gst::CLOCK_TIME_NONE
|
||||
}
|
||||
}
|
||||
|
||||
impl Transcriber {
|
||||
fn set_pad_functions(sinkpad: &gst::Pad, srcpad: &gst::Pad) {
|
||||
sinkpad.set_chain_function(|pad, parent, buffer| {
|
||||
|
@ -272,7 +265,7 @@ impl Transcriber {
|
|||
let latency: gst::ClockTime = (self.settings.lock().unwrap().latency_ms as u64
|
||||
- GRANULARITY_MS as u64)
|
||||
* gst::MSECOND;
|
||||
let now = get_current_running_time(element);
|
||||
let now = element.get_current_running_time();
|
||||
|
||||
while let Some(buf) = state.buffers.front() {
|
||||
if now - buf.get_pts() > latency {
|
||||
|
@ -379,8 +372,43 @@ impl Transcriber {
|
|||
|
||||
match msg {
|
||||
Message::Binary(buf) => {
|
||||
let (_, pkt) = parse_packet(&buf).unwrap();
|
||||
let (_, pkt) = parse_packet(&buf).map_err(|err| {
|
||||
gst_error!(CAT, obj: element, "Failed to parse packet: {}", err);
|
||||
gst_error_msg!(
|
||||
gst::StreamError::Failed,
|
||||
["Failed to parse packet: {}", err]
|
||||
)
|
||||
})?;
|
||||
|
||||
let payload = std::str::from_utf8(pkt.payload).unwrap();
|
||||
|
||||
if packet_is_exception(&pkt) {
|
||||
let message: ExceptionMessage =
|
||||
serde_json::from_str(&payload).map_err(|err| {
|
||||
gst_error!(
|
||||
CAT,
|
||||
obj: element,
|
||||
"Unexpected exception message: {} ({})",
|
||||
payload,
|
||||
err
|
||||
);
|
||||
gst_error_msg!(
|
||||
gst::StreamError::Failed,
|
||||
["Unexpected exception message: {} ({})", payload, err]
|
||||
)
|
||||
})?;
|
||||
gst_error!(
|
||||
CAT,
|
||||
obj: element,
|
||||
"AWS raised an error: {}",
|
||||
message.message
|
||||
);
|
||||
Err(gst_error_msg!(
|
||||
gst::StreamError::Failed,
|
||||
["AWS raised an error: {}", message.message]
|
||||
))?;
|
||||
}
|
||||
|
||||
let mut transcript: Transcript =
|
||||
serde_json::from_str(&payload).map_err(|err| {
|
||||
gst_error_msg!(
|
||||
|
@ -683,7 +711,7 @@ impl Transcriber {
|
|||
|
||||
if let Some(buffer) = &buffer {
|
||||
let running_time = state.in_segment.to_running_time(buffer.get_pts());
|
||||
let now = get_current_running_time(&element);
|
||||
let now = element.get_current_running_time();
|
||||
|
||||
if now.is_some() && now < running_time {
|
||||
delay = Some(running_time - now);
|
||||
|
|
|
@ -126,14 +126,24 @@ fn parse_header(input: &[u8]) -> IResult<&[u8], Header> {
|
|||
let (input, value) = map_res(take(value_length), std::str::from_utf8)(input)?;
|
||||
|
||||
let header = Header {
|
||||
name: Cow::Owned(name.to_string()),
|
||||
name: name.to_string().into(),
|
||||
value_type,
|
||||
value: Cow::Owned(value.to_string()),
|
||||
value: value.to_string().into(),
|
||||
};
|
||||
|
||||
Ok((input, header))
|
||||
}
|
||||
|
||||
pub fn packet_is_exception(packet: &Packet) -> bool {
|
||||
for header in &packet.headers {
|
||||
if header.name == ":message-type" && header.value_type == 7 {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
pub fn parse_packet(input: &[u8]) -> IResult<&[u8], Packet> {
|
||||
let (remainder, prelude) = parse_prelude(input)?;
|
||||
|
||||
|
|
Loading…
Reference in a new issue