mirror of
https://gitlab.freedesktop.org/gstreamer/gst-plugins-rs.git
synced 2025-02-19 14:25:20 +00:00
speechmaticstranscriber: add properties for speaker detection
diarization=speaker can be set to enable speaker detection, and max-speakers can be set to control the maximum number of detected speakers. An event is then forwarded downstream upon speaker changes. Part-of: <https://gitlab.freedesktop.org/gstreamer/gst-plugins-rs/-/merge_requests/2055>
This commit is contained in:
parent
9da6dff1a9
commit
2d0effd781
3 changed files with 238 additions and 52 deletions
|
@ -29,6 +29,8 @@ use atomic_refcell::AtomicRefCell;
|
|||
|
||||
use std::sync::LazyLock;
|
||||
|
||||
use super::SpeechmaticsTranscriberDiarization;
|
||||
|
||||
#[derive(serde::Deserialize, Debug)]
|
||||
#[allow(dead_code)]
|
||||
struct TranscriptMetadata {
|
||||
|
@ -52,6 +54,8 @@ struct TranscriptAlternative {
|
|||
language: Option<String>,
|
||||
#[serde(default)]
|
||||
tags: Vec<String>,
|
||||
#[serde(default)]
|
||||
speaker: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize, Debug)]
|
||||
|
@ -113,12 +117,36 @@ struct Vocable {
|
|||
sounds_like: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize, Debug)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
enum Diarization {
|
||||
None,
|
||||
Speaker,
|
||||
}
|
||||
|
||||
impl From<SpeechmaticsTranscriberDiarization> for Diarization {
|
||||
fn from(val: SpeechmaticsTranscriberDiarization) -> Self {
|
||||
use SpeechmaticsTranscriberDiarization::*;
|
||||
match val {
|
||||
None => Diarization::None,
|
||||
Speaker => Diarization::Speaker,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize, Debug)]
|
||||
struct SpeakerDiarizationConfig {
|
||||
max_speakers: u32,
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize, Debug)]
|
||||
struct TranscriptionConfig {
|
||||
language: String,
|
||||
enable_partials: bool,
|
||||
max_delay: f32,
|
||||
additional_vocab: Vec<Vocable>,
|
||||
diarization: Diarization,
|
||||
speaker_diarization_config: SpeakerDiarizationConfig,
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize, Debug)]
|
||||
|
@ -162,6 +190,9 @@ const DEFAULT_MAX_DELAY_MS: u32 = 0;
|
|||
const DEFAULT_LATENESS_MS: u32 = 0;
|
||||
const DEFAULT_JOIN_PUNCTUATION: bool = true;
|
||||
const DEFAULT_ENABLE_LATE_PUNCTUATION_HACK: bool = true;
|
||||
const DEFAULT_DIARIZATION: SpeechmaticsTranscriberDiarization =
|
||||
SpeechmaticsTranscriberDiarization::None;
|
||||
const DEFAULT_MAX_SPEAKERS: u32 = 50;
|
||||
const GRANULARITY_MS: u32 = 100;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
|
@ -174,6 +205,8 @@ struct Settings {
|
|||
api_key: Option<String>,
|
||||
join_punctuation: bool,
|
||||
enable_late_punctuation_hack: bool,
|
||||
diarization: SpeechmaticsTranscriberDiarization,
|
||||
max_speakers: u32,
|
||||
}
|
||||
|
||||
impl Default for Settings {
|
||||
|
@ -187,6 +220,8 @@ impl Default for Settings {
|
|||
api_key: None,
|
||||
join_punctuation: DEFAULT_JOIN_PUNCTUATION,
|
||||
enable_late_punctuation_hack: DEFAULT_ENABLE_LATE_PUNCTUATION_HACK,
|
||||
diarization: DEFAULT_DIARIZATION,
|
||||
max_speakers: DEFAULT_MAX_SPEAKERS,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -196,6 +231,7 @@ struct ItemAccumulator {
|
|||
text: String,
|
||||
start_time: gst::ClockTime,
|
||||
end_time: gst::ClockTime,
|
||||
speaker: Option<String>,
|
||||
}
|
||||
|
||||
impl From<ItemAccumulator> for gst::Buffer {
|
||||
|
@ -290,6 +326,8 @@ impl TranscriberSrcPad {
|
|||
let (latency, now, mut last_position, send_eos, seqnum) = {
|
||||
let mut state = self.state.lock().unwrap();
|
||||
|
||||
let current_speaker = state.current_speaker.clone();
|
||||
|
||||
if let Some(ref mut accumulator_inner) = state.accumulator {
|
||||
if now.saturating_sub(accumulator_inner.start_time + start_time) + granularity
|
||||
> latency
|
||||
|
@ -303,37 +341,46 @@ impl TranscriberSrcPad {
|
|||
accumulator_inner.start_time,
|
||||
accumulator_inner.end_time
|
||||
);
|
||||
|
||||
if current_speaker != accumulator_inner.speaker {
|
||||
let new_speaker = accumulator_inner.speaker.clone();
|
||||
state.push_speaker(new_speaker);
|
||||
}
|
||||
|
||||
let buf = state.accumulator.take().unwrap().into();
|
||||
state.push_buffer(buf);
|
||||
}
|
||||
}
|
||||
|
||||
let send_eos =
|
||||
state.send_eos && state.buffers.is_empty() && state.accumulator.is_none();
|
||||
state.send_eos && state.output_items.is_empty() && state.accumulator.is_none();
|
||||
|
||||
while let Some(mut buf) = state.buffers.pop_front() {
|
||||
{
|
||||
let buf_mut = buf.make_mut();
|
||||
let mut pts = buf_mut.pts().unwrap() + start_time;
|
||||
let mut duration = buf_mut.duration().unwrap();
|
||||
if let Some(position) = state.out_segment.position() {
|
||||
if pts < position {
|
||||
gst::debug!(
|
||||
CAT,
|
||||
imp = self,
|
||||
"Adjusting item timing({:?} < {:?})",
|
||||
pts,
|
||||
position,
|
||||
);
|
||||
duration = duration.saturating_sub(position - pts);
|
||||
pts = position;
|
||||
while let Some(item) = state.output_items.pop_front() {
|
||||
match item {
|
||||
TranscriberOutput::Buffer(mut buf) => {
|
||||
let buf_mut = buf.make_mut();
|
||||
let mut pts = buf_mut.pts().unwrap() + start_time;
|
||||
let mut duration = buf_mut.duration().unwrap();
|
||||
if let Some(position) = state.out_segment.position() {
|
||||
if pts < position {
|
||||
gst::debug!(
|
||||
CAT,
|
||||
imp = self,
|
||||
"Adjusting item timing({:?} < {:?})",
|
||||
pts,
|
||||
position,
|
||||
);
|
||||
duration = duration.saturating_sub(position - pts);
|
||||
pts = position;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
buf_mut.set_pts(pts);
|
||||
buf_mut.set_duration(duration);
|
||||
buf_mut.set_pts(pts);
|
||||
buf_mut.set_duration(duration);
|
||||
items.push(TranscriberOutput::Buffer(buf));
|
||||
}
|
||||
_ => items.push(item),
|
||||
}
|
||||
items.push(buf);
|
||||
}
|
||||
|
||||
(
|
||||
|
@ -354,33 +401,40 @@ impl TranscriberSrcPad {
|
|||
.push_event(gst::event::Eos::builder().seqnum(seqnum).build());
|
||||
}
|
||||
|
||||
for buf in items.drain(..) {
|
||||
let pts = buf.pts().unwrap();
|
||||
for item in items.drain(..) {
|
||||
match item {
|
||||
TranscriberOutput::Buffer(buf) => {
|
||||
let pts = buf.pts().unwrap();
|
||||
|
||||
if let Some(last_position) = last_position {
|
||||
if pts > last_position {
|
||||
let gap_event = gst::event::Gap::builder(last_position)
|
||||
.duration(pts - last_position)
|
||||
.seqnum(seqnum)
|
||||
.build();
|
||||
gst::log!(CAT, "Pushing gap: {} -> {}", last_position, pts);
|
||||
if !self.obj().push_event(gap_event) {
|
||||
if let Some(last_position) = last_position {
|
||||
if pts > last_position {
|
||||
let gap_event = gst::event::Gap::builder(last_position)
|
||||
.duration(pts - last_position)
|
||||
.seqnum(seqnum)
|
||||
.build();
|
||||
gst::log!(CAT, "Pushing gap: {} -> {}", last_position, pts);
|
||||
if !self.obj().push_event(gap_event) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let pts_end = if let Some(duration) = buf.duration() {
|
||||
pts + duration
|
||||
} else {
|
||||
pts
|
||||
};
|
||||
last_position = Some(pts_end);
|
||||
|
||||
gst::debug!(CAT, imp = self, "Pushing buffer: {} -> {}", pts, pts_end,);
|
||||
|
||||
if self.obj().push(buf).is_err() {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let pts_end = if let Some(duration) = buf.duration() {
|
||||
pts + duration
|
||||
} else {
|
||||
pts
|
||||
};
|
||||
last_position = Some(pts_end);
|
||||
|
||||
gst::debug!(CAT, imp = self, "Pushing buffer: {} -> {}", pts, pts_end,);
|
||||
|
||||
if self.obj().push(buf).is_err() {
|
||||
return false;
|
||||
TranscriberOutput::Event(event) => {
|
||||
let _ = self.obj().push_event(event);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -481,12 +535,18 @@ impl TranscriberSrcPad {
|
|||
gst::debug!(
|
||||
CAT,
|
||||
imp = self,
|
||||
"Item is ready: \"{}\", start_time: {}, end_time: {}",
|
||||
"Item is ready: \"{}\", start_time: {}, end_time: {}, speaker: {:?}",
|
||||
accumulator_inner.text,
|
||||
accumulator_inner.start_time,
|
||||
accumulator_inner.end_time
|
||||
accumulator_inner.end_time,
|
||||
accumulator_inner.speaker
|
||||
);
|
||||
|
||||
if state.current_speaker != accumulator_inner.speaker {
|
||||
let new_speaker = accumulator_inner.speaker.clone();
|
||||
state.push_speaker(new_speaker);
|
||||
}
|
||||
|
||||
let buffer = state.accumulator.take().unwrap().into();
|
||||
|
||||
state.push_buffer(buffer);
|
||||
|
@ -494,6 +554,7 @@ impl TranscriberSrcPad {
|
|||
text: alternative.content.clone(),
|
||||
start_time,
|
||||
end_time,
|
||||
speaker: alternative.speaker.clone(),
|
||||
});
|
||||
}
|
||||
} else if join_punctuation {
|
||||
|
@ -501,6 +562,7 @@ impl TranscriberSrcPad {
|
|||
text: alternative.content.clone(),
|
||||
start_time,
|
||||
end_time,
|
||||
speaker: alternative.speaker.clone(),
|
||||
});
|
||||
} else {
|
||||
let text = alternative.content.clone();
|
||||
|
@ -508,12 +570,17 @@ impl TranscriberSrcPad {
|
|||
gst::debug!(
|
||||
CAT,
|
||||
imp = self,
|
||||
"Item is ready: \"{}\", start_time: {}, end_time: {}",
|
||||
"Item is ready: \"{}\", start_time: {}, end_time: {}, speaker: {:?}",
|
||||
text,
|
||||
start_time,
|
||||
end_time
|
||||
end_time,
|
||||
alternative.speaker,
|
||||
);
|
||||
|
||||
if state.current_speaker != alternative.speaker {
|
||||
state.push_speaker(alternative.speaker.clone());
|
||||
}
|
||||
|
||||
let mut buf = gst::Buffer::from_slice(text.into_bytes());
|
||||
|
||||
{
|
||||
|
@ -1295,6 +1362,10 @@ impl Transcriber {
|
|||
enable_partials: false,
|
||||
max_delay,
|
||||
additional_vocab: state.additional_vocabulary.clone(),
|
||||
diarization: settings.diarization.into(),
|
||||
speaker_diarization_config: SpeakerDiarizationConfig {
|
||||
max_speakers: settings.max_speakers,
|
||||
},
|
||||
},
|
||||
translation_config: TranslationConfig {
|
||||
target_languages: translation_languages,
|
||||
|
@ -1304,7 +1375,7 @@ impl Transcriber {
|
|||
|
||||
let message = serde_json::to_string(&start_message).unwrap();
|
||||
|
||||
gst::trace!(CAT, imp = self, "Sending start message: {}", message);
|
||||
gst::debug!(CAT, imp = self, "Sending start message: {}", message);
|
||||
|
||||
RUNTIME
|
||||
.block_on(ws_sink.send(Message::Text(message)))
|
||||
|
@ -1619,6 +1690,16 @@ impl ObjectImpl for Transcriber {
|
|||
.default_value(DEFAULT_ENABLE_LATE_PUNCTUATION_HACK)
|
||||
.mutable_ready()
|
||||
.build(),
|
||||
glib::ParamSpecEnum::builder_with_default("diarization", DEFAULT_DIARIZATION)
|
||||
.nick("Diarization")
|
||||
.blurb("Defines how to separate speakers in the audio")
|
||||
.mutable_ready()
|
||||
.build(),
|
||||
glib::ParamSpecUInt::builder("max-speakers")
|
||||
.nick("Max Speakers")
|
||||
.blurb("The maximum number of speakers that may be detected with diarization=speaker")
|
||||
.default_value(DEFAULT_MAX_SPEAKERS)
|
||||
.build(),
|
||||
]
|
||||
});
|
||||
|
||||
|
@ -1746,6 +1827,16 @@ impl ObjectImpl for Transcriber {
|
|||
let mut settings = self.settings.lock().unwrap();
|
||||
settings.enable_late_punctuation_hack = value.get().expect("type checked upstream");
|
||||
}
|
||||
"diarization" => {
|
||||
let mut settings = self.settings.lock().unwrap();
|
||||
settings.diarization = value
|
||||
.get::<SpeechmaticsTranscriberDiarization>()
|
||||
.expect("type checked upstream");
|
||||
}
|
||||
"max-speakers" => {
|
||||
let mut settings = self.settings.lock().unwrap();
|
||||
settings.max_speakers = value.get().expect("type checked upstream");
|
||||
}
|
||||
_ => unimplemented!(),
|
||||
}
|
||||
}
|
||||
|
@ -1801,6 +1892,8 @@ impl ObjectImpl for Transcriber {
|
|||
let settings = self.settings.lock().unwrap();
|
||||
settings.enable_late_punctuation_hack.to_value()
|
||||
}
|
||||
"diarization" => self.settings.lock().unwrap().diarization.to_value(),
|
||||
"max-speakers" => self.settings.lock().unwrap().max_speakers.to_value(),
|
||||
_ => unimplemented!(),
|
||||
}
|
||||
}
|
||||
|
@ -1999,16 +2092,23 @@ struct TranscriberSrcPadSettings {
|
|||
language_code: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum TranscriberOutput {
|
||||
Buffer(gst::Buffer),
|
||||
Event(gst::Event),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct TranscriberSrcPadState {
|
||||
sender: Option<mpsc::Sender<Message>>,
|
||||
accumulator: Option<ItemAccumulator>,
|
||||
buffers: VecDeque<gst::Buffer>,
|
||||
output_items: VecDeque<TranscriberOutput>,
|
||||
discont: bool,
|
||||
send_eos: bool,
|
||||
out_segment: gst::FormattedSegment<gst::ClockTime>,
|
||||
seqnum: gst::Seqnum,
|
||||
unsynced_pad: Option<gst::Pad>,
|
||||
current_speaker: Option<String>,
|
||||
}
|
||||
|
||||
impl Default for TranscriberSrcPadState {
|
||||
|
@ -2016,12 +2116,13 @@ impl Default for TranscriberSrcPadState {
|
|||
Self {
|
||||
sender: None,
|
||||
accumulator: None,
|
||||
buffers: VecDeque::new(),
|
||||
output_items: VecDeque::new(),
|
||||
discont: true,
|
||||
send_eos: false,
|
||||
out_segment: gst::FormattedSegment::new(),
|
||||
seqnum: gst::Seqnum::next(),
|
||||
unsynced_pad: None,
|
||||
current_speaker: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2040,7 +2141,24 @@ impl TranscriberSrcPadState {
|
|||
self.discont = false;
|
||||
}
|
||||
|
||||
self.buffers.push_back(buf);
|
||||
self.output_items.push_back(TranscriberOutput::Buffer(buf));
|
||||
}
|
||||
|
||||
fn push_event(&mut self, event: gst::Event) {
|
||||
self.output_items.push_back(TranscriberOutput::Event(event));
|
||||
}
|
||||
|
||||
fn push_speaker(&mut self, speaker: Option<String>) {
|
||||
let event = gst::event::CustomDownstream::builder(
|
||||
gst::Structure::builder("rstranscribe/speaker-change")
|
||||
.field("speaker", &speaker)
|
||||
.build(),
|
||||
)
|
||||
.build();
|
||||
|
||||
self.current_speaker = speaker;
|
||||
|
||||
self.push_event(event);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -19,10 +19,23 @@ glib::wrapper! {
|
|||
pub struct TranscriberSrcPad(ObjectSubclass<imp::TranscriberSrcPad>) @extends gst::Pad, gst::Object;
|
||||
}
|
||||
|
||||
#[derive(Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Clone, Copy, glib::Enum)]
|
||||
#[repr(u32)]
|
||||
#[enum_type(name = "GstSpeechmaticsTranscriberDiarization")]
|
||||
#[non_exhaustive]
|
||||
pub enum SpeechmaticsTranscriberDiarization {
|
||||
#[enum_value(name = "None: no diarization", nick = "none")]
|
||||
None = 0,
|
||||
#[enum_value(name = "Speaker: identify speakers by their voices", nick = "speaker")]
|
||||
Speaker = 1,
|
||||
}
|
||||
|
||||
pub fn register(plugin: &gst::Plugin) -> Result<(), glib::BoolError> {
|
||||
#[cfg(feature = "doc")]
|
||||
{
|
||||
TranscriberSrcPad::static_type().mark_as_plugin_api(gst::PluginAPIFlags::empty());
|
||||
SpeechmaticsTranscriberDiarization::static_type()
|
||||
.mark_as_plugin_api(gst::PluginAPIFlags::empty());
|
||||
}
|
||||
gst::Element::register(
|
||||
Some(plugin),
|
||||
|
|
|
@ -14099,6 +14099,18 @@
|
|||
"type": "gchararray",
|
||||
"writable": true
|
||||
},
|
||||
"diarization": {
|
||||
"blurb": "Defines how to separate speakers in the audio",
|
||||
"conditionally-available": false,
|
||||
"construct": false,
|
||||
"construct-only": false,
|
||||
"controllable": false,
|
||||
"default": "none (0)",
|
||||
"mutable": "ready",
|
||||
"readable": true,
|
||||
"type": "GstSpeechmaticsTranscriberDiarization",
|
||||
"writable": true
|
||||
},
|
||||
"enable-late-punctuation-hack": {
|
||||
"blurb": "Pass a reduced max-delay to speechmatics to make sure we always get punctuation in time for joining it with the preceding word.",
|
||||
"conditionally-available": false,
|
||||
|
@ -14163,6 +14175,34 @@
|
|||
"type": "guint",
|
||||
"writable": true
|
||||
},
|
||||
"max-delay": {
|
||||
"blurb": "Max delay to pass to the speechmatics API (0 = use latency)",
|
||||
"conditionally-available": false,
|
||||
"construct": false,
|
||||
"construct-only": false,
|
||||
"controllable": false,
|
||||
"default": "0",
|
||||
"max": "-1",
|
||||
"min": "0",
|
||||
"mutable": "null",
|
||||
"readable": true,
|
||||
"type": "guint",
|
||||
"writable": true
|
||||
},
|
||||
"max-speakers": {
|
||||
"blurb": "The maximum number of speakers that may be detected with diarization=speaker",
|
||||
"conditionally-available": false,
|
||||
"construct": false,
|
||||
"construct-only": false,
|
||||
"controllable": false,
|
||||
"default": "50",
|
||||
"max": "-1",
|
||||
"min": "0",
|
||||
"mutable": "null",
|
||||
"readable": true,
|
||||
"type": "guint",
|
||||
"writable": true
|
||||
},
|
||||
"url": {
|
||||
"blurb": "URL of the transcription server",
|
||||
"conditionally-available": false,
|
||||
|
@ -14182,6 +14222,21 @@
|
|||
"filename": "gstspeechmatics",
|
||||
"license": "Proprietary",
|
||||
"other-types": {
|
||||
"GstSpeechmaticsTranscriberDiarization": {
|
||||
"kind": "enum",
|
||||
"values": [
|
||||
{
|
||||
"desc": "None: no diarization",
|
||||
"name": "none",
|
||||
"value": "0"
|
||||
},
|
||||
{
|
||||
"desc": "Speaker: identify speakers by their voices",
|
||||
"name": "speaker",
|
||||
"value": "1"
|
||||
}
|
||||
]
|
||||
},
|
||||
"GstSpeechmaticsTranscriberSrcPad": {
|
||||
"hierarchy": [
|
||||
"GstSpeechmaticsTranscriberSrcPad",
|
||||
|
|
Loading…
Reference in a new issue