gst-plugin-transcribe: address review comments

This commit is contained in:
Mathieu Duponchelle 2020-04-10 19:52:28 +02:00 committed by Sebastian Dröge
parent a2b3b70f3b
commit a31b3c5c83
4 changed files with 192 additions and 158 deletions

View file

@ -1,18 +1,16 @@
[package] [package]
name = "gst-plugin-transcribe" name = "gst-plugin-transcribe"
version = "0.1.0" version = "0.1.0"
authors = ["Jordan Petridis <jordan@centricular.com>"] authors = ["Jordan Petridis <jordan@centricular.com>", "Mathieu Duponchelle <mathieu@centricular.com>"]
edition = "2018" edition = "2018"
description = "AWS Transcribe plugin" description = "AWS Transcribe plugin"
repository = "https://gitlab.freedesktop.org/gstreamer/gst-plugins-rs" repository = "https://gitlab.freedesktop.org/gstreamer/gst-plugins-rs"
# FIXME: licence license = "LGPL-2.1+"
[dependencies] [dependencies]
glib = { git = "https://github.com/gtk-rs/glib" } glib = { git = "https://github.com/gtk-rs/glib" }
# FIXME: gstreamer = { git = "https://gitlab.freedesktop.org/gstreamer/gstreamer-rs" }
# gst = { git = "https://gitlab.freedesktop.org/gstreamer/gstreamer-rs", features = ["v1_16"], package = "gstreamer" } gstreamer-base = { git = "https://gitlab.freedesktop.org/gstreamer/gstreamer-rs" }
gstreamer = { git = "https://gitlab.freedesktop.org/gstreamer/gstreamer-rs", features = ["v1_16"] }
gst_base = { git = "https://gitlab.freedesktop.org/gstreamer/gstreamer-rs", features = ["v1_16"], package ="gstreamer-base" }
rusoto_core = "0.43.0-beta.1" rusoto_core = "0.43.0-beta.1"
rusoto_credential = "0.43.0-beta.1" rusoto_credential = "0.43.0-beta.1"
rusoto_transcribe = "0.43.0-beta.1" rusoto_transcribe = "0.43.0-beta.1"
@ -24,7 +22,7 @@ async-tungstenite = { version = "0.4", features = ["tokio", "tokio-runtime", "to
nom = "5.1.1" nom = "5.1.1"
crc = "1.8.1" crc = "1.8.1"
byteorder = "1.3.4" byteorder = "1.3.4"
lazy_static = "1.4.0" once_cell = "1.0"
serde = "1" serde = "1"
serde_derive = "1" serde_derive = "1"
serde_json = "1" serde_json = "1"

View file

@ -1,4 +1,19 @@
#![allow(unused)] // Copyright (C) 2020 Mathieu Duponchelle <mathieu@centricular.com>
//
// This library is free software; you can redistribute it and/or
// modify it under the terms of the GNU Library General Public
// License as published by the Free Software Foundation; either
// version 2 of the License, or (at your option) any later version.
//
// This library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
// Library General Public License for more details.
//
// You should have received a copy of the GNU Library General Public
// License along with this library; if not, write to the
// Free Software Foundation, Inc., 51 Franklin Street, Suite 500,
// Boston, MA 02110-1335, USA.
use glib; use glib;
use glib::prelude::*; use glib::prelude::*;
@ -8,45 +23,32 @@ use gst;
use gst::prelude::*; use gst::prelude::*;
use gst::subclass::prelude::*; use gst::subclass::prelude::*;
use std::convert::TryInto;
use std::default::Default; use std::default::Default;
use std::env;
use rusoto_core::Region; use rusoto_core::Region;
use rusoto_credential; use rusoto_credential;
use rusoto_credential::{EnvironmentProvider, ProvideAwsCredentials}; use rusoto_credential::{EnvironmentProvider, ProvideAwsCredentials};
use rusoto_transcribe;
use rusoto_signature::signature::SignedRequest; use rusoto_signature::signature::SignedRequest;
use rusoto_signature::signature::SignedRequestPayload;
use rusoto_signature::stream::ByteStream;
use rusoto_transcribe::Media;
use rusoto_transcribe::Settings as TranscriptionSettings;
use rusoto_transcribe::StartTranscriptionJobRequest;
use rusoto_transcribe::Transcribe;
use async_tungstenite::tungstenite::error::Error as WsError; use async_tungstenite::tungstenite::error::Error as WsError;
use async_tungstenite::{tokio::connect_async, tungstenite::Message}; use async_tungstenite::{tokio::connect_async, tungstenite::Message};
use futures::channel::mpsc; use futures::channel::mpsc;
use futures::future::{abortable, AbortHandle}; use futures::future::{abortable, AbortHandle};
use futures::io::{AsyncReadExt, Cursor};
use futures::prelude::*; use futures::prelude::*;
use futures::stream::SplitSink;
use tokio::runtime; use tokio::runtime;
use crc::crc32; use std::borrow::Cow;
use std::boxed::Box;
use std::collections::VecDeque; use std::collections::VecDeque;
use std::pin::Pin; use std::pin::Pin;
use std::sync::{Mutex, MutexGuard}; use std::sync::Mutex;
use std::time::Duration; use std::time::Duration;
use lazy_static;
use crate::packet::*; use crate::packet::*;
use serde_derive::{Deserialize, Serialize}; use serde_derive::Deserialize;
use once_cell::sync::Lazy;
#[derive(Deserialize, Debug)] #[derive(Deserialize, Debug)]
#[serde(rename_all = "PascalCase")] #[serde(rename_all = "PascalCase")]
@ -87,21 +89,22 @@ struct Transcript {
transcript: TranscriptTranscript, transcript: TranscriptTranscript,
} }
lazy_static! { static CAT: Lazy<gst::DebugCategory> = Lazy::new(|| {
static ref CAT: gst::DebugCategory = { gst::DebugCategory::new(
gst::DebugCategory::new( "awstranscribe",
"awstranscribe", gst::DebugColorFlags::empty(),
gst::DebugColorFlags::empty(), Some("AWS Transcribe element"),
Some("AWS Transcribe element"), )
) });
};
static ref RUNTIME: runtime::Runtime = runtime::Builder::new() static RUNTIME: Lazy<runtime::Runtime> = Lazy::new(|| {
runtime::Builder::new()
.threaded_scheduler() .threaded_scheduler()
.enable_all() .enable_all()
.core_threads(1) .core_threads(1)
.build() .build()
.unwrap(); .unwrap()
} });
const DEFAULT_LATENCY_MS: u32 = 30000; const DEFAULT_LATENCY_MS: u32 = 30000;
const GRANULARITY_MS: u32 = 100; const GRANULARITY_MS: u32 = 100;
@ -110,7 +113,7 @@ static PROPERTIES: [subclass::Property; 2] = [
subclass::Property("language-code", |name| { subclass::Property("language-code", |name| {
glib::ParamSpec::string( glib::ParamSpec::string(
name, name,
"Language-Code", "Language Code",
"The Language of the Stream, see \ "The Language of the Stream, see \
<https://docs.aws.amazon.com/transcribe/latest/dg/how-streaming-transcription.html> \ <https://docs.aws.amazon.com/transcribe/latest/dg/how-streaming-transcription.html> \
for an up to date list of allowed languages", for an up to date list of allowed languages",
@ -121,7 +124,7 @@ static PROPERTIES: [subclass::Property; 2] = [
subclass::Property("latency", |name| { subclass::Property("latency", |name| {
glib::ParamSpec::uint( glib::ParamSpec::uint(
name, name,
"latency in ms", "Latency",
"Amount of milliseconds to allow AWS transcribe", "Amount of milliseconds to allow AWS transcribe",
GRANULARITY_MS, GRANULARITY_MS,
std::u32::MAX, std::u32::MAX,
@ -176,29 +179,31 @@ impl Default for State {
} }
} }
type WsSink = Pin<Box<dyn Sink<Message, Error = WsError> + Send>>;
struct Transcriber { struct Transcriber {
srcpad: gst::Pad, srcpad: gst::Pad,
sinkpad: gst::Pad, sinkpad: gst::Pad,
settings: Mutex<Settings>, settings: Mutex<Settings>,
state: Mutex<State>, state: Mutex<State>,
ws_sink: Mutex<Option<Pin<Box<dyn Sink<Message, Error = WsError> + Send>>>>, ws_sink: Mutex<Option<WsSink>>,
} }
fn build_packet(payload: &[u8]) -> Vec<u8> { fn build_packet(payload: &[u8]) -> Vec<u8> {
let headers = [ let headers = [
Header { Header {
name: String::from(":event-type"), name: Cow::Borrowed(":event-type"),
value: String::from("AudioEvent"), value: Cow::Borrowed("AudioEvent"),
value_type: 7, value_type: 7,
}, },
Header { Header {
name: String::from(":content-type"), name: Cow::Borrowed(":content-type"),
value: String::from("application/octet-stream"), value: Cow::Borrowed("application/octet-stream"),
value_type: 7, value_type: 7,
}, },
Header { Header {
name: String::from(":message-type"), name: Cow::Borrowed(":message-type"),
value: String::from("event"), value: Cow::Borrowed("event"),
value_type: 7, value_type: 7,
}, },
]; ];
@ -256,32 +261,27 @@ impl Transcriber {
}); });
} }
fn push_gap(&self, element: &gst::Element) -> bool { fn dequeue(&self, element: &gst::Element) -> bool {
/* First, check our pending buffers */ /* First, check our pending buffers */
let mut items: Vec<gst::Buffer> = vec![]; let mut items = vec![];
let (latency, now, mut last_position, send_eos, seqnum) = { let (latency, now, mut last_position, send_eos, seqnum) = {
let mut state = self.state.lock().unwrap(); let mut state = self.state.lock().unwrap();
let send_eos = state.send_eos && state.buffers.len() == 0; let send_eos = state.send_eos && state.buffers.is_empty();
let latency: gst::ClockTime = ((self.settings.lock().unwrap().latency_ms as u64 let latency: gst::ClockTime = (self.settings.lock().unwrap().latency_ms as u64
- GRANULARITY_MS as u64) - GRANULARITY_MS as u64)
* 1000000) * gst::MSECOND;
.into();
let now = get_current_running_time(element); let now = get_current_running_time(element);
loop { while let Some(buf) = state.buffers.front() {
if let Some(buf) = state.buffers.front() { if now - buf.get_pts() > latency {
if now - buf.get_pts() > latency { /* Safe unwrap, we know we have an item */
/* Safe unwrap, we know we have an item */ let buf = state.buffers.pop_front().unwrap();
let buf = state.buffers.pop_front().unwrap(); items.push(buf);
items.push(buf) } else {
} else { break;
break;
}
} }
break;
} }
( (
@ -295,7 +295,8 @@ impl Transcriber {
/* We're EOS, we can pause and exit early */ /* We're EOS, we can pause and exit early */
if send_eos { if send_eos {
self.srcpad.pause_task(); let _ = self.srcpad.pause_task();
return self return self
.srcpad .srcpad
.push_event(gst::Event::new_eos().seqnum(seqnum).build()); .push_event(gst::Event::new_eos().seqnum(seqnum).build());
@ -318,7 +319,7 @@ impl Transcriber {
} }
last_position = buf.get_pts() + buf.get_duration(); last_position = buf.get_pts() + buf.get_duration();
{ {
let mut buf = buf.get_mut().unwrap(); let buf = buf.get_mut().unwrap();
buf.set_pts(buf.get_pts()); buf.set_pts(buf.get_pts());
} }
gst_debug!( gst_debug!(
@ -327,7 +328,7 @@ impl Transcriber {
buf.get_pts(), buf.get_pts(),
buf.get_pts() + buf.get_duration() buf.get_pts() + buf.get_duration()
); );
if !self.srcpad.push(buf).is_ok() { if self.srcpad.push(buf).is_err() {
return false; return false;
} }
} }
@ -371,32 +372,33 @@ impl Transcriber {
Some(msg) => msg, Some(msg) => msg,
/* Sender was closed */ /* Sender was closed */
None => { None => {
self.srcpad.pause_task(); let _ = self.srcpad.pause_task();
return Ok(()); return Ok(());
} }
}; };
match msg { match msg {
Message::Binary(buf) => { Message::Binary(buf) => {
let (data, pkt) = parse_packet(&buf).unwrap(); let (_, pkt) = parse_packet(&buf).unwrap();
let payload = std::str::from_utf8(pkt.payload).unwrap(); let payload = std::str::from_utf8(pkt.payload).unwrap();
let transcript: Transcript = serde_json::from_str(&payload).map_err(|err| { let mut transcript: Transcript =
gst_error_msg!( serde_json::from_str(&payload).map_err(|err| {
gst::StreamError::Failed, gst_error_msg!(
["Unexpected binary message: {} ({})", payload, err] gst::StreamError::Failed,
) ["Unexpected binary message: {} ({})", payload, err]
})?; )
})?;
if transcript.transcript.results.len() > 0 { if !transcript.transcript.results.is_empty() {
let result = &transcript.transcript.results[0]; let mut result = transcript.transcript.results.remove(0);
if !result.is_partial && result.alternatives.len() > 0 { if !result.is_partial && !result.alternatives.is_empty() {
let alternative = &result.alternatives[0]; let alternative = result.alternatives.remove(0);
gst_info!(CAT, obj: element, "Transcript: {}", alternative.transcript); gst_info!(CAT, obj: element, "Transcript: {}", alternative.transcript);
let mut start_time: gst::ClockTime = let mut start_time: gst::ClockTime =
((result.start_time as f64 * 1000000000.0) as u64).into(); ((result.start_time as f64 * 1_000_000_000.0) as u64).into();
let end_time: gst::ClockTime = let end_time: gst::ClockTime =
((result.end_time as f64 * 1000000000.0) as u64).into(); ((result.end_time as f64 * 1_000_000_000.0) as u64).into();
let mut state = self.state.lock().unwrap(); let mut state = self.state.lock().unwrap();
let position = state.out_segment.get_position(); let position = state.out_segment.get_position();
@ -414,11 +416,11 @@ impl Transcriber {
} }
let mut buf = gst::Buffer::from_mut_slice( let mut buf = gst::Buffer::from_mut_slice(
alternative.transcript.as_bytes().to_vec(), alternative.transcript.into_bytes(),
); );
{ {
let mut buf = buf.get_mut().unwrap(); let buf = buf.get_mut().unwrap();
if state.discont { if state.discont {
buf.set_flags(gst::BufferFlags::DISCONT); buf.set_flags(gst::BufferFlags::DISCONT);
@ -446,16 +448,18 @@ impl Transcriber {
let future = async move { let future = async move {
match tokio::time::timeout(Duration::from_millis(GRANULARITY_MS.into()), future).await { match tokio::time::timeout(Duration::from_millis(GRANULARITY_MS.into()), future).await {
Err(_) => { Err(_) => {
if !self.push_gap(element) { if !self.dequeue(element) {
gst_info!(CAT, obj: element, "Failed to push gap event, pausing"); gst_info!(CAT, obj: element, "Failed to push gap event, pausing");
self.srcpad.pause_task();
let _ = self.srcpad.pause_task();
} }
Ok(()) Ok(())
} }
Ok(res) => { Ok(res) => {
if !self.push_gap(element) { if !self.dequeue(element) {
gst_info!(CAT, obj: element, "Failed to push gap event, pausing"); gst_info!(CAT, obj: element, "Failed to push gap event, pausing");
self.srcpad.pause_task();
let _ = self.srcpad.pause_task();
} }
res res
} }
@ -480,21 +484,24 @@ impl Transcriber {
Some(element) => element, Some(element) => element,
None => { None => {
if let Some(pad) = pad_weak.upgrade() { if let Some(pad) = pad_weak.upgrade() {
pad.pause_task().unwrap(); let _ = pad.pause_task();
} }
return; return;
} }
}; };
let transcribe = Self::from_instance(&element); let transcribe = Self::from_instance(&element);
transcribe.loop_fn(&element, &mut receiver).map_err(|err| { match transcribe.loop_fn(&element, &mut receiver) {
gst_element_error!( Err(err) => {
&element, gst_element_error!(
gst::StreamError::Failed, &element,
["Streaming failed: {}", err] gst::StreamError::Failed,
); ["Streaming failed: {}", err]
transcribe.srcpad.pause_task().unwrap(); );
}); let _ = transcribe.srcpad.pause_task();
}
Ok(_) => (),
};
}); });
if res.is_err() { if res.is_err() {
return Err(gst_loggable_error!(CAT, "Failed to start pad task")); return Err(gst_loggable_error!(CAT, "Failed to start pad task"));
@ -506,7 +513,7 @@ impl Transcriber {
&self, &self,
_pad: &gst::Pad, _pad: &gst::Pad,
element: &gst::Element, element: &gst::Element,
mode: gst::PadMode, _mode: gst::PadMode,
active: bool, active: bool,
) -> Result<(), gst::LoggableError> { ) -> Result<(), gst::LoggableError> {
if active { if active {
@ -537,7 +544,7 @@ impl Transcriber {
if ret { if ret {
let (_, min, _) = peer_query.get_result(); let (_, min, _) = peer_query.get_result();
let our_latency: gst::ClockTime = let our_latency: gst::ClockTime =
(self.settings.lock().unwrap().latency_ms as u64 * 1000000).into(); self.settings.lock().unwrap().latency_ms as u64 * gst::MSECOND;
q.set(true, our_latency + min, gst::CLOCK_TIME_NONE); q.set(true, our_latency + min, gst::CLOCK_TIME_NONE);
} }
ret ret
@ -547,8 +554,8 @@ impl Transcriber {
let state = self.state.lock().unwrap(); let state = self.state.lock().unwrap();
q.set( q.set(
state state
.in_segment .out_segment
.to_stream_time(state.in_segment.get_position()), .to_stream_time(state.out_segment.get_position()),
); );
true true
} else { } else {
@ -572,18 +579,47 @@ impl Transcriber {
} }
Ok(_) => true, Ok(_) => true,
}, },
EventView::FlushStart(e) => { EventView::FlushStart(_) => {
gst_info!(CAT, obj: element, "Received flush start, disconnecting"); gst_info!(CAT, obj: element, "Received flush start, disconnecting");
self.disconnect(element); match self.disconnect(element) {
let ret = pad.event_default(Some(element), event); Err(err) => {
self.srcpad.stop_task(); element.post_error_message(&err);
ret false
}
Ok(_) => {
let mut ret = pad.event_default(Some(element), event);
match self.srcpad.stop_task() {
Err(err) => {
gst_error!(
CAT,
obj: element,
"Failed to stop srcpad task: {}",
err
);
ret = false;
}
Ok(_) => (),
};
ret
}
}
} }
EventView::FlushStop(e) => { EventView::FlushStop(_) => {
gst_info!(CAT, obj: element, "Received flush stop, restarting task"); gst_info!(CAT, obj: element, "Received flush stop, restarting task");
let ret = pad.event_default(Some(element), event);
self.start_task(element); if pad.event_default(Some(element), event) {
ret match self.start_task(element) {
Err(err) => {
gst_error!(CAT, obj: element, "Failed to start srcpad task: {}", err);
false
}
Ok(_) => true,
}
} else {
false
}
} }
EventView::Segment(e) => { EventView::Segment(e) => {
let segment = match e.get_segment().clone().downcast::<gst::ClockTime>() { let segment = match e.get_segment().clone().downcast::<gst::ClockTime>() {
@ -684,7 +720,7 @@ impl Transcriber {
fn handle_buffer( fn handle_buffer(
&self, &self,
pad: &gst::Pad, _pad: &gst::Pad,
element: &gst::Element, element: &gst::Element,
buffer: Option<gst::Buffer>, buffer: Option<gst::Buffer>,
) -> Result<gst::FlowSuccess, gst::FlowError> { ) -> Result<gst::FlowSuccess, gst::FlowError> {
@ -703,12 +739,10 @@ impl Transcriber {
self.state.lock().unwrap().send_abort_handle = Some(abort_handle); self.state.lock().unwrap().send_abort_handle = Some(abort_handle);
let ret = match RUNTIME.enter(|| futures::executor::block_on(future)) { match RUNTIME.enter(|| futures::executor::block_on(future)) {
Err(err) => Err(gst::FlowError::Flushing), Err(_) => Err(gst::FlowError::Flushing),
Ok(res) => res, Ok(res) => res,
}; }
ret
} }
fn sink_chain( fn sink_chain(
@ -723,7 +757,7 @@ impl Transcriber {
fn ensure_connection(&self, element: &gst::Element) -> Result<(), gst::ErrorMessage> { fn ensure_connection(&self, element: &gst::Element) -> Result<(), gst::ErrorMessage> {
let mut state = self.state.lock().unwrap(); let mut state = self.state.lock().unwrap();
if state.connected == true { if state.connected {
return Ok(()); return Ok(());
} }
@ -750,29 +784,6 @@ impl Transcriber {
.as_ref() .as_ref()
.expect("Language code is required"); .expect("Language code is required");
let request = StartTranscriptionJobRequest {
job_execution_settings: None,
language_code: language_code.to_string(),
media: Media {
media_file_uri: None,
},
media_format: Some("wav".to_string()),
media_sample_rate_hertz: Some(sample_rate.into()),
output_bucket_name: None,
output_encryption_kms_key_id: None,
settings: Some(TranscriptionSettings {
channel_identification: None,
max_alternatives: None,
max_speaker_labels: None,
show_alternatives: Some(false),
show_speaker_labels: None,
vocabulary_filter_method: None,
vocabulary_filter_name: None,
vocabulary_name: None,
}),
transcription_job_name: element.get_name().to_string(),
};
let region = Region::UsEast1; let region = Region::UsEast1;
let mut signed = SignedRequest::new( let mut signed = SignedRequest::new(
@ -790,7 +801,7 @@ impl Transcriber {
signed.add_param("sample-rate", &sample_rate.to_string()); signed.add_param("sample-rate", &sample_rate.to_string());
let url = signed.generate_presigned_url(&creds, &std::time::Duration::from_secs(60), true); let url = signed.generate_presigned_url(&creds, &std::time::Duration::from_secs(60), true);
let (mut ws, _) = RUNTIME let (ws, _) = RUNTIME
.enter(|| futures::executor::block_on(connect_async(format!("wss{}", &url[5..])))) .enter(|| futures::executor::block_on(connect_async(format!("wss{}", &url[5..]))))
.map_err(|err| { .map_err(|err| {
gst_error!(CAT, obj: element, "Failed to connect: {}", err); gst_error!(CAT, obj: element, "Failed to connect: {}", err);
@ -803,11 +814,7 @@ impl Transcriber {
let element_weak = element.downgrade(); let element_weak = element.downgrade();
let future = async move { let future = async move {
loop { while let Some(element) = element_weak.upgrade() {
let element = match element_weak.upgrade() {
Some(element) => element,
None => break,
};
let transcribe = Self::from_instance(&element); let transcribe = Self::from_instance(&element);
let msg = match ws_stream.next().await { let msg = match ws_stream.next().await {
Some(msg) => msg, Some(msg) => msg,
@ -834,7 +841,9 @@ impl Transcriber {
let mut sender = transcribe.state.lock().unwrap().sender.clone(); let mut sender = transcribe.state.lock().unwrap().sender.clone();
if let Some(sender) = sender.as_mut() { if let Some(sender) = sender.as_mut() {
sender.send(msg).await; if sender.send(msg).await.is_err() {
break;
}
} }
} }
}; };

View file

@ -1,10 +1,22 @@
// FIXME: add lgpl 2.1 license // Copyright (C) 2020 Mathieu Duponchelle <mathieu@centricular.com>
//
// This library is free software; you can redistribute it and/or
// modify it under the terms of the GNU Library General Public
// License as published by the Free Software Foundation; either
// version 2 of the License, or (at your option) any later version.
//
// This library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
// Library General Public License for more details.
//
// You should have received a copy of the GNU Library General Public
// License along with this library; if not, write to the
// Free Software Foundation, Inc., 51 Franklin Street, Suite 500,
// Boston, MA 02110-1335, USA.
#![crate_type = "cdylib"] #![crate_type = "cdylib"]
#[macro_use]
extern crate lazy_static;
#[macro_use] #[macro_use]
extern crate glib; extern crate glib;
#[macro_use] #[macro_use]

View file

@ -1,30 +1,47 @@
// Copyright (C) 2020 Jordan Petridis <jordan@centricular.com>
//
// This library is free software; you can redistribute it and/or
// modify it under the terms of the GNU Library General Public
// License as published by the Free Software Foundation; either
// version 2 of the License, or (at your option) any later version.
//
// This library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
// Library General Public License for more details.
//
// You should have received a copy of the GNU Library General Public
// License along with this library; if not, write to the
// Free Software Foundation, Inc., 51 Franklin Street, Suite 500,
// Boston, MA 02110-1335, USA.
use byteorder::{BigEndian, WriteBytesExt}; use byteorder::{BigEndian, WriteBytesExt};
use crc::crc32; use crc::crc32;
use nom::{ use nom::{
self, bytes::complete::take, combinator::map_res, multi::many0, number::complete::be_u16, self, bytes::complete::take, combinator::map_res, multi::many0, number::complete::be_u16,
number::complete::be_u32, number::complete::be_u8, sequence::tuple, IResult, number::complete::be_u32, number::complete::be_u8, sequence::tuple, IResult,
}; };
use std::borrow::Cow;
use std::io::{self, Write}; use std::io::{self, Write};
#[derive(Debug)] #[derive(Debug)]
pub struct Prelude { struct Prelude {
total_bytes: u32, total_bytes: u32,
header_bytes: u32, header_bytes: u32,
prelude_crc: u32, prelude_crc: u32,
} }
#[derive(Debug)] #[derive(Debug)]
// FIXME: make private
pub struct Header { pub struct Header {
pub name: String, pub name: Cow<'static, str>,
pub value_type: u8, pub value_type: u8,
pub value: String, pub value: Cow<'static, str>,
} }
#[derive(Debug)] #[derive(Debug)]
pub struct Packet<'a> { pub struct Packet<'a> {
prelude: Prelude, prelude: Prelude,
pub headers: Vec<Header>, headers: Vec<Header>,
pub payload: &'a [u8], pub payload: &'a [u8],
msg_crc: u32, msg_crc: u32,
} }
@ -109,9 +126,9 @@ fn parse_header(input: &[u8]) -> IResult<&[u8], Header> {
let (input, value) = map_res(take(value_length), std::str::from_utf8)(input)?; let (input, value) = map_res(take(value_length), std::str::from_utf8)(input)?;
let header = Header { let header = Header {
name: name.to_string(), name: Cow::Owned(name.to_string()),
value_type, value_type,
value: value.to_string(), value: Cow::Owned(value.to_string()),
}; };
Ok((input, header)) Ok((input, header))
@ -125,7 +142,6 @@ pub fn parse_packet(input: &[u8]) -> IResult<&[u8], Packet> {
let (_, msg_crc) = be_u32(&input[input.len() - 4..])?; let (_, msg_crc) = be_u32(&input[input.len() - 4..])?;
if msg_crc != sum { if msg_crc != sum {
// FIXME: a better errortype than mapres
return Err(nom::Err::Error(( return Err(nom::Err::Error((
b"Prelude CRC doesn't match", b"Prelude CRC doesn't match",
nom::error::ErrorKind::MapRes, nom::error::ErrorKind::MapRes,
@ -134,7 +150,6 @@ pub fn parse_packet(input: &[u8]) -> IResult<&[u8], Packet> {
let (remainder, header_input) = take(prelude.header_bytes)(remainder)?; let (remainder, header_input) = take(prelude.header_bytes)(remainder)?;
let (_, headers) = many0(parse_header)(header_input)?; let (_, headers) = many0(parse_header)(header_input)?;
//dbg!(&headers);
let payload_length = prelude.total_bytes - prelude.header_bytes - 4 - 12; let payload_length = prelude.total_bytes - prelude.header_bytes - 4 - 12;
let (remainder, payload) = take(payload_length)(remainder)?; let (remainder, payload) = take(payload_length)(remainder)?;