mirror of
https://gitlab.freedesktop.org/gstreamer/gst-plugins-rs.git
synced 2024-11-25 21:11:00 +00:00
gst-plugin-transcribe: address latest review comments
This commit is contained in:
parent
a31b3c5c83
commit
7c79f73a4c
2 changed files with 62 additions and 24 deletions
|
@ -38,7 +38,6 @@ use futures::future::{abortable, AbortHandle};
|
||||||
use futures::prelude::*;
|
use futures::prelude::*;
|
||||||
use tokio::runtime;
|
use tokio::runtime;
|
||||||
|
|
||||||
use std::borrow::Cow;
|
|
||||||
use std::collections::VecDeque;
|
use std::collections::VecDeque;
|
||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
use std::sync::Mutex;
|
use std::sync::Mutex;
|
||||||
|
@ -89,6 +88,12 @@ struct Transcript {
|
||||||
transcript: TranscriptTranscript,
|
transcript: TranscriptTranscript,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Debug)]
|
||||||
|
#[serde(rename_all = "PascalCase")]
|
||||||
|
struct ExceptionMessage {
|
||||||
|
message: String,
|
||||||
|
}
|
||||||
|
|
||||||
static CAT: Lazy<gst::DebugCategory> = Lazy::new(|| {
|
static CAT: Lazy<gst::DebugCategory> = Lazy::new(|| {
|
||||||
gst::DebugCategory::new(
|
gst::DebugCategory::new(
|
||||||
"awstranscribe",
|
"awstranscribe",
|
||||||
|
@ -192,18 +197,18 @@ struct Transcriber {
|
||||||
fn build_packet(payload: &[u8]) -> Vec<u8> {
|
fn build_packet(payload: &[u8]) -> Vec<u8> {
|
||||||
let headers = [
|
let headers = [
|
||||||
Header {
|
Header {
|
||||||
name: Cow::Borrowed(":event-type"),
|
name: ":event-type".into(),
|
||||||
value: Cow::Borrowed("AudioEvent"),
|
value: "AudioEvent".into(),
|
||||||
value_type: 7,
|
value_type: 7,
|
||||||
},
|
},
|
||||||
Header {
|
Header {
|
||||||
name: Cow::Borrowed(":content-type"),
|
name: ":content-type".into(),
|
||||||
value: Cow::Borrowed("application/octet-stream"),
|
value: "application/octet-stream".into(),
|
||||||
value_type: 7,
|
value_type: 7,
|
||||||
},
|
},
|
||||||
Header {
|
Header {
|
||||||
name: Cow::Borrowed(":message-type"),
|
name: ":message-type".into(),
|
||||||
value: Cow::Borrowed("event"),
|
value: "event".into(),
|
||||||
value_type: 7,
|
value_type: 7,
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
|
@ -211,18 +216,6 @@ fn build_packet(payload: &[u8]) -> Vec<u8> {
|
||||||
encode_packet(payload, &headers).expect("foobar")
|
encode_packet(payload, &headers).expect("foobar")
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_current_running_time(element: &gst::Element) -> gst::ClockTime {
|
|
||||||
if let Some(clock) = element.get_clock() {
|
|
||||||
if clock.get_time() > element.get_base_time() {
|
|
||||||
clock.get_time() - element.get_base_time()
|
|
||||||
} else {
|
|
||||||
0.into()
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
gst::CLOCK_TIME_NONE
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Transcriber {
|
impl Transcriber {
|
||||||
fn set_pad_functions(sinkpad: &gst::Pad, srcpad: &gst::Pad) {
|
fn set_pad_functions(sinkpad: &gst::Pad, srcpad: &gst::Pad) {
|
||||||
sinkpad.set_chain_function(|pad, parent, buffer| {
|
sinkpad.set_chain_function(|pad, parent, buffer| {
|
||||||
|
@ -272,7 +265,7 @@ impl Transcriber {
|
||||||
let latency: gst::ClockTime = (self.settings.lock().unwrap().latency_ms as u64
|
let latency: gst::ClockTime = (self.settings.lock().unwrap().latency_ms as u64
|
||||||
- GRANULARITY_MS as u64)
|
- GRANULARITY_MS as u64)
|
||||||
* gst::MSECOND;
|
* gst::MSECOND;
|
||||||
let now = get_current_running_time(element);
|
let now = element.get_current_running_time();
|
||||||
|
|
||||||
while let Some(buf) = state.buffers.front() {
|
while let Some(buf) = state.buffers.front() {
|
||||||
if now - buf.get_pts() > latency {
|
if now - buf.get_pts() > latency {
|
||||||
|
@ -379,8 +372,43 @@ impl Transcriber {
|
||||||
|
|
||||||
match msg {
|
match msg {
|
||||||
Message::Binary(buf) => {
|
Message::Binary(buf) => {
|
||||||
let (_, pkt) = parse_packet(&buf).unwrap();
|
let (_, pkt) = parse_packet(&buf).map_err(|err| {
|
||||||
|
gst_error!(CAT, obj: element, "Failed to parse packet: {}", err);
|
||||||
|
gst_error_msg!(
|
||||||
|
gst::StreamError::Failed,
|
||||||
|
["Failed to parse packet: {}", err]
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
let payload = std::str::from_utf8(pkt.payload).unwrap();
|
let payload = std::str::from_utf8(pkt.payload).unwrap();
|
||||||
|
|
||||||
|
if packet_is_exception(&pkt) {
|
||||||
|
let message: ExceptionMessage =
|
||||||
|
serde_json::from_str(&payload).map_err(|err| {
|
||||||
|
gst_error!(
|
||||||
|
CAT,
|
||||||
|
obj: element,
|
||||||
|
"Unexpected exception message: {} ({})",
|
||||||
|
payload,
|
||||||
|
err
|
||||||
|
);
|
||||||
|
gst_error_msg!(
|
||||||
|
gst::StreamError::Failed,
|
||||||
|
["Unexpected exception message: {} ({})", payload, err]
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
gst_error!(
|
||||||
|
CAT,
|
||||||
|
obj: element,
|
||||||
|
"AWS raised an error: {}",
|
||||||
|
message.message
|
||||||
|
);
|
||||||
|
Err(gst_error_msg!(
|
||||||
|
gst::StreamError::Failed,
|
||||||
|
["AWS raised an error: {}", message.message]
|
||||||
|
))?;
|
||||||
|
}
|
||||||
|
|
||||||
let mut transcript: Transcript =
|
let mut transcript: Transcript =
|
||||||
serde_json::from_str(&payload).map_err(|err| {
|
serde_json::from_str(&payload).map_err(|err| {
|
||||||
gst_error_msg!(
|
gst_error_msg!(
|
||||||
|
@ -683,7 +711,7 @@ impl Transcriber {
|
||||||
|
|
||||||
if let Some(buffer) = &buffer {
|
if let Some(buffer) = &buffer {
|
||||||
let running_time = state.in_segment.to_running_time(buffer.get_pts());
|
let running_time = state.in_segment.to_running_time(buffer.get_pts());
|
||||||
let now = get_current_running_time(&element);
|
let now = element.get_current_running_time();
|
||||||
|
|
||||||
if now.is_some() && now < running_time {
|
if now.is_some() && now < running_time {
|
||||||
delay = Some(running_time - now);
|
delay = Some(running_time - now);
|
||||||
|
|
|
@ -126,14 +126,24 @@ fn parse_header(input: &[u8]) -> IResult<&[u8], Header> {
|
||||||
let (input, value) = map_res(take(value_length), std::str::from_utf8)(input)?;
|
let (input, value) = map_res(take(value_length), std::str::from_utf8)(input)?;
|
||||||
|
|
||||||
let header = Header {
|
let header = Header {
|
||||||
name: Cow::Owned(name.to_string()),
|
name: name.to_string().into(),
|
||||||
value_type,
|
value_type,
|
||||||
value: Cow::Owned(value.to_string()),
|
value: value.to_string().into(),
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok((input, header))
|
Ok((input, header))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn packet_is_exception(packet: &Packet) -> bool {
|
||||||
|
for header in &packet.headers {
|
||||||
|
if header.name == ":message-type" && header.value_type == 7 {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
pub fn parse_packet(input: &[u8]) -> IResult<&[u8], Packet> {
|
pub fn parse_packet(input: &[u8]) -> IResult<&[u8], Packet> {
|
||||||
let (remainder, prelude) = parse_prelude(input)?;
|
let (remainder, prelude) = parse_prelude(input)?;
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue