From be00ae7999d5b17a868ff67d5d2880d12a7df730 Mon Sep 17 00:00:00 2001 From: Mathieu Duponchelle Date: Mon, 2 Dec 2024 19:51:51 +0100 Subject: [PATCH] aws/polly: expose property for overflow control Part-of: --- docs/plugins/gst_plugins_cache.json | 32 +++++++++++++++ net/aws/src/polly/imp.rs | 60 ++++++++++++++++++++++++++--- net/aws/src/polly/mod.rs | 14 +++++++ 3 files changed, 100 insertions(+), 6 deletions(-) diff --git a/docs/plugins/gst_plugins_cache.json b/docs/plugins/gst_plugins_cache.json index 376274c5..1c3abea1 100644 --- a/docs/plugins/gst_plugins_cache.json +++ b/docs/plugins/gst_plugins_cache.json @@ -87,6 +87,18 @@ "type": "GstValueArray", "writable": true }, + "overflow": { + "blurb": "Defines how output audio with a longer duration than input text should be handled", + "conditionally-available": false, + "construct": false, + "construct-only": false, + "controllable": false, + "default": "clip (0)", + "mutable": "ready", + "readable": true, + "type": "GstAwsOverflow", + "writable": true + }, "secret-access-key": { "blurb": "AWS Secret Access Key", "conditionally-available": false, @@ -1372,6 +1384,26 @@ "filename": "gstaws", "license": "MPL", "other-types": { + "GstAwsOverflow": { + "kind": "enum", + "values": [ + { + "desc": "Clip", + "name": "clip", + "value": "0" + }, + { + "desc": "Overlap", + "name": "overlap", + "value": "1" + }, + { + "desc": "Shift", + "name": "shift", + "value": "2" + } + ] + }, "GstAwsPollyEngine": { "kind": "enum", "values": [ diff --git a/net/aws/src/polly/imp.rs b/net/aws/src/polly/imp.rs index 9a53451e..af475236 100644 --- a/net/aws/src/polly/imp.rs +++ b/net/aws/src/polly/imp.rs @@ -21,7 +21,7 @@ use std::sync::Mutex; use std::sync::LazyLock; -use super::{AwsPollyEngine, AwsPollyLanguageCode, AwsPollyVoiceId, CAT}; +use super::{AwsOverflow, AwsPollyEngine, AwsPollyLanguageCode, AwsPollyVoiceId, CAT}; use crate::s3utils::RUNTIME; use anyhow::{anyhow, Error}; @@ -35,6 +35,7 @@ const DEFAULT_ENGINE: AwsPollyEngine = AwsPollyEngine::Neural; const DEFAULT_LANGUAGE_CODE: AwsPollyLanguageCode = AwsPollyLanguageCode::None; const DEFAULT_VOICE_ID: AwsPollyVoiceId = AwsPollyVoiceId::Aria; const DEFAULT_SSML_SET_MAX_DURATION: bool = false; +const DEFAULT_OVERFLOW: AwsOverflow = AwsOverflow::Clip; #[derive(Debug, Clone)] pub(super) struct Settings { @@ -47,6 +48,7 @@ pub(super) struct Settings { voice_id: AwsPollyVoiceId, lexicon_names: gst::Array, ssml_set_max_duration: bool, + overflow: AwsOverflow, } impl Default for Settings { @@ -61,6 +63,7 @@ impl Default for Settings { voice_id: DEFAULT_VOICE_ID, lexicon_names: gst::Array::default(), ssml_set_max_duration: DEFAULT_SSML_SET_MAX_DURATION, + overflow: DEFAULT_OVERFLOW, } } } @@ -158,7 +161,7 @@ impl Polly { } async fn send(&self, inbuf: gst::Buffer) -> Result { - let pts = inbuf + let mut pts = inbuf .pts() .ok_or_else(|| anyhow!("Stream with timestamped buffers required"))?; @@ -227,15 +230,47 @@ impl Polly { })?; let blob = resp.audio_stream.collect().await?; - let mut buf = gst::Buffer::from_slice(blob.into_bytes()); - let mut state = self.state.lock().unwrap(); + let mut bytes = blob.into_bytes(); + + let overflow = self.settings.lock().unwrap().overflow; + + if matches!(overflow, AwsOverflow::Clip) { + let max_expected_bytes = duration + .nseconds() + .mul_div_floor(32_000, 1_000_000_000) + .unwrap() + / 2 + * 2; + + gst::debug!( + CAT, + "Received {} bytes, max expected {}", + bytes.len(), + max_expected_bytes + ); + + bytes.truncate(max_expected_bytes as usize); + } let duration = gst::ClockTime::from_nseconds( - (buf.size() as u64) + (bytes.len() as u64) .mul_div_round(1_000_000_000, 32_000) .unwrap(), ); + let mut buf = gst::Buffer::from_slice(bytes); + let mut state = self.state.lock().unwrap(); + + if let Some(position) = state.out_segment.position() { + if matches!(overflow, AwsOverflow::Shift) && pts < position { + gst::debug!( + CAT, + "received pts {pts} < position {position}, shifting forward" + ); + pts = position; + } + } + let discont = state .out_segment .position() @@ -248,7 +283,7 @@ impl Polly { buf_mut.set_duration(duration); if discont { - gst::log!(CAT, imp = self, "Marking buffer discont"); + gst::debug!(CAT, imp = self, "Marking buffer discont"); buf_mut.set_flags(gst::BufferFlags::DISCONT); } inbuf.foreach_meta(|meta| { @@ -521,6 +556,11 @@ impl ObjectImpl for Polly { .default_value(DEFAULT_SSML_SET_MAX_DURATION) .mutable_ready() .build(), + glib::ParamSpecEnum::builder_with_default("overflow", DEFAULT_OVERFLOW) + .nick("Overflow") + .blurb("Defines how output audio with a longer duration than input text should be handled") + .mutable_ready() + .build(), ] }); @@ -581,6 +621,10 @@ impl ObjectImpl for Polly { let mut settings = self.settings.lock().unwrap(); settings.ssml_set_max_duration = value.get().expect("type checked upstream"); } + "overflow" => { + let mut settings = self.settings.lock().unwrap(); + settings.overflow = value.get::().expect("type checked upstream"); + } _ => unimplemented!(), } } @@ -623,6 +667,10 @@ impl ObjectImpl for Polly { let settings = self.settings.lock().unwrap(); settings.ssml_set_max_duration.to_value() } + "overflow" => { + let settings = self.settings.lock().unwrap(); + settings.overflow.to_value() + } _ => unimplemented!(), } } diff --git a/net/aws/src/polly/mod.rs b/net/aws/src/polly/mod.rs index 945522f1..eee0b706 100644 --- a/net/aws/src/polly/mod.rs +++ b/net/aws/src/polly/mod.rs @@ -398,6 +398,19 @@ impl From for LanguageCode { } } +#[derive(Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Clone, Copy, glib::Enum)] +#[repr(u32)] +#[enum_type(name = "GstAwsOverflow")] +#[non_exhaustive] +pub enum AwsOverflow { + #[enum_value(name = "Clip", nick = "clip")] + Clip = 0, + #[enum_value(name = "Overlap", nick = "overlap")] + Overlap = 1, + #[enum_value(name = "Shift", nick = "shift")] + Shift = 2, +} + glib::wrapper! { pub struct Polly(ObjectSubclass) @extends gst::Element, gst::Object; } @@ -408,6 +421,7 @@ pub fn register(plugin: &gst::Plugin) -> Result<(), glib::BoolError> { AwsPollyEngine::static_type().mark_as_plugin_api(gst::PluginAPIFlags::empty()); AwsPollyVoiceId::static_type().mark_as_plugin_api(gst::PluginAPIFlags::empty()); AwsPollyLanguageCode::static_type().mark_as_plugin_api(gst::PluginAPIFlags::empty()); + AwsOverflow::static_type().mark_as_plugin_api(gst::PluginAPIFlags::empty()); } gst::Element::register( Some(plugin),