aws/polly: expose property for overflow control

Part-of: <https://gitlab.freedesktop.org/gstreamer/gst-plugins-rs/-/merge_requests/1965>
This commit is contained in:
Mathieu Duponchelle 2024-12-02 19:51:51 +01:00 committed by GStreamer Marge Bot
parent 4852a4a5e6
commit be00ae7999
3 changed files with 100 additions and 6 deletions

View file

@ -87,6 +87,18 @@
"type": "GstValueArray", "type": "GstValueArray",
"writable": true "writable": true
}, },
"overflow": {
"blurb": "Defines how output audio with a longer duration than input text should be handled",
"conditionally-available": false,
"construct": false,
"construct-only": false,
"controllable": false,
"default": "clip (0)",
"mutable": "ready",
"readable": true,
"type": "GstAwsOverflow",
"writable": true
},
"secret-access-key": { "secret-access-key": {
"blurb": "AWS Secret Access Key", "blurb": "AWS Secret Access Key",
"conditionally-available": false, "conditionally-available": false,
@ -1372,6 +1384,26 @@
"filename": "gstaws", "filename": "gstaws",
"license": "MPL", "license": "MPL",
"other-types": { "other-types": {
"GstAwsOverflow": {
"kind": "enum",
"values": [
{
"desc": "Clip",
"name": "clip",
"value": "0"
},
{
"desc": "Overlap",
"name": "overlap",
"value": "1"
},
{
"desc": "Shift",
"name": "shift",
"value": "2"
}
]
},
"GstAwsPollyEngine": { "GstAwsPollyEngine": {
"kind": "enum", "kind": "enum",
"values": [ "values": [

View file

@ -21,7 +21,7 @@ use std::sync::Mutex;
use std::sync::LazyLock; use std::sync::LazyLock;
use super::{AwsPollyEngine, AwsPollyLanguageCode, AwsPollyVoiceId, CAT}; use super::{AwsOverflow, AwsPollyEngine, AwsPollyLanguageCode, AwsPollyVoiceId, CAT};
use crate::s3utils::RUNTIME; use crate::s3utils::RUNTIME;
use anyhow::{anyhow, Error}; use anyhow::{anyhow, Error};
@ -35,6 +35,7 @@ const DEFAULT_ENGINE: AwsPollyEngine = AwsPollyEngine::Neural;
const DEFAULT_LANGUAGE_CODE: AwsPollyLanguageCode = AwsPollyLanguageCode::None; const DEFAULT_LANGUAGE_CODE: AwsPollyLanguageCode = AwsPollyLanguageCode::None;
const DEFAULT_VOICE_ID: AwsPollyVoiceId = AwsPollyVoiceId::Aria; const DEFAULT_VOICE_ID: AwsPollyVoiceId = AwsPollyVoiceId::Aria;
const DEFAULT_SSML_SET_MAX_DURATION: bool = false; const DEFAULT_SSML_SET_MAX_DURATION: bool = false;
const DEFAULT_OVERFLOW: AwsOverflow = AwsOverflow::Clip;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub(super) struct Settings { pub(super) struct Settings {
@ -47,6 +48,7 @@ pub(super) struct Settings {
voice_id: AwsPollyVoiceId, voice_id: AwsPollyVoiceId,
lexicon_names: gst::Array, lexicon_names: gst::Array,
ssml_set_max_duration: bool, ssml_set_max_duration: bool,
overflow: AwsOverflow,
} }
impl Default for Settings { impl Default for Settings {
@ -61,6 +63,7 @@ impl Default for Settings {
voice_id: DEFAULT_VOICE_ID, voice_id: DEFAULT_VOICE_ID,
lexicon_names: gst::Array::default(), lexicon_names: gst::Array::default(),
ssml_set_max_duration: DEFAULT_SSML_SET_MAX_DURATION, ssml_set_max_duration: DEFAULT_SSML_SET_MAX_DURATION,
overflow: DEFAULT_OVERFLOW,
} }
} }
} }
@ -158,7 +161,7 @@ impl Polly {
} }
async fn send(&self, inbuf: gst::Buffer) -> Result<gst::Buffer, Error> { async fn send(&self, inbuf: gst::Buffer) -> Result<gst::Buffer, Error> {
let pts = inbuf let mut pts = inbuf
.pts() .pts()
.ok_or_else(|| anyhow!("Stream with timestamped buffers required"))?; .ok_or_else(|| anyhow!("Stream with timestamped buffers required"))?;
@ -227,15 +230,47 @@ impl Polly {
})?; })?;
let blob = resp.audio_stream.collect().await?; let blob = resp.audio_stream.collect().await?;
let mut buf = gst::Buffer::from_slice(blob.into_bytes()); let mut bytes = blob.into_bytes();
let mut state = self.state.lock().unwrap();
let overflow = self.settings.lock().unwrap().overflow;
if matches!(overflow, AwsOverflow::Clip) {
let max_expected_bytes = duration
.nseconds()
.mul_div_floor(32_000, 1_000_000_000)
.unwrap()
/ 2
* 2;
gst::debug!(
CAT,
"Received {} bytes, max expected {}",
bytes.len(),
max_expected_bytes
);
bytes.truncate(max_expected_bytes as usize);
}
let duration = gst::ClockTime::from_nseconds( let duration = gst::ClockTime::from_nseconds(
(buf.size() as u64) (bytes.len() as u64)
.mul_div_round(1_000_000_000, 32_000) .mul_div_round(1_000_000_000, 32_000)
.unwrap(), .unwrap(),
); );
let mut buf = gst::Buffer::from_slice(bytes);
let mut state = self.state.lock().unwrap();
if let Some(position) = state.out_segment.position() {
if matches!(overflow, AwsOverflow::Shift) && pts < position {
gst::debug!(
CAT,
"received pts {pts} < position {position}, shifting forward"
);
pts = position;
}
}
let discont = state let discont = state
.out_segment .out_segment
.position() .position()
@ -248,7 +283,7 @@ impl Polly {
buf_mut.set_duration(duration); buf_mut.set_duration(duration);
if discont { if discont {
gst::log!(CAT, imp = self, "Marking buffer discont"); gst::debug!(CAT, imp = self, "Marking buffer discont");
buf_mut.set_flags(gst::BufferFlags::DISCONT); buf_mut.set_flags(gst::BufferFlags::DISCONT);
} }
inbuf.foreach_meta(|meta| { inbuf.foreach_meta(|meta| {
@ -521,6 +556,11 @@ impl ObjectImpl for Polly {
.default_value(DEFAULT_SSML_SET_MAX_DURATION) .default_value(DEFAULT_SSML_SET_MAX_DURATION)
.mutable_ready() .mutable_ready()
.build(), .build(),
glib::ParamSpecEnum::builder_with_default("overflow", DEFAULT_OVERFLOW)
.nick("Overflow")
.blurb("Defines how output audio with a longer duration than input text should be handled")
.mutable_ready()
.build(),
] ]
}); });
@ -581,6 +621,10 @@ impl ObjectImpl for Polly {
let mut settings = self.settings.lock().unwrap(); let mut settings = self.settings.lock().unwrap();
settings.ssml_set_max_duration = value.get().expect("type checked upstream"); settings.ssml_set_max_duration = value.get().expect("type checked upstream");
} }
"overflow" => {
let mut settings = self.settings.lock().unwrap();
settings.overflow = value.get::<AwsOverflow>().expect("type checked upstream");
}
_ => unimplemented!(), _ => unimplemented!(),
} }
} }
@ -623,6 +667,10 @@ impl ObjectImpl for Polly {
let settings = self.settings.lock().unwrap(); let settings = self.settings.lock().unwrap();
settings.ssml_set_max_duration.to_value() settings.ssml_set_max_duration.to_value()
} }
"overflow" => {
let settings = self.settings.lock().unwrap();
settings.overflow.to_value()
}
_ => unimplemented!(), _ => unimplemented!(),
} }
} }

View file

@ -398,6 +398,19 @@ impl From<AwsPollyLanguageCode> for LanguageCode {
} }
} }
#[derive(Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Clone, Copy, glib::Enum)]
#[repr(u32)]
#[enum_type(name = "GstAwsOverflow")]
#[non_exhaustive]
pub enum AwsOverflow {
#[enum_value(name = "Clip", nick = "clip")]
Clip = 0,
#[enum_value(name = "Overlap", nick = "overlap")]
Overlap = 1,
#[enum_value(name = "Shift", nick = "shift")]
Shift = 2,
}
glib::wrapper! { glib::wrapper! {
pub struct Polly(ObjectSubclass<imp::Polly>) @extends gst::Element, gst::Object; pub struct Polly(ObjectSubclass<imp::Polly>) @extends gst::Element, gst::Object;
} }
@ -408,6 +421,7 @@ pub fn register(plugin: &gst::Plugin) -> Result<(), glib::BoolError> {
AwsPollyEngine::static_type().mark_as_plugin_api(gst::PluginAPIFlags::empty()); AwsPollyEngine::static_type().mark_as_plugin_api(gst::PluginAPIFlags::empty());
AwsPollyVoiceId::static_type().mark_as_plugin_api(gst::PluginAPIFlags::empty()); AwsPollyVoiceId::static_type().mark_as_plugin_api(gst::PluginAPIFlags::empty());
AwsPollyLanguageCode::static_type().mark_as_plugin_api(gst::PluginAPIFlags::empty()); AwsPollyLanguageCode::static_type().mark_as_plugin_api(gst::PluginAPIFlags::empty());
AwsOverflow::static_type().mark_as_plugin_api(gst::PluginAPIFlags::empty());
} }
gst::Element::register( gst::Element::register(
Some(plugin), Some(plugin),