speechmaticstranscriber: add properties for speaker detection

diarization=speaker can be set to enable speaker detection, and
max-speakers can be set to control the maximum number of detected
speakers.

A custom downstream event (`rstranscribe/speaker-change`, carrying the new
speaker) is then forwarded downstream upon speaker changes.

Part-of: <https://gitlab.freedesktop.org/gstreamer/gst-plugins-rs/-/merge_requests/2055>
This commit is contained in:
Mathieu Duponchelle 2025-02-03 17:27:31 +01:00 committed by GStreamer Marge Bot
parent 9da6dff1a9
commit 2d0effd781
3 changed files with 238 additions and 52 deletions

View file

@ -29,6 +29,8 @@ use atomic_refcell::AtomicRefCell;
use std::sync::LazyLock;
use super::SpeechmaticsTranscriberDiarization;
#[derive(serde::Deserialize, Debug)]
#[allow(dead_code)]
struct TranscriptMetadata {
@ -52,6 +54,8 @@ struct TranscriptAlternative {
language: Option<String>,
#[serde(default)]
tags: Vec<String>,
#[serde(default)]
speaker: Option<String>,
}
#[derive(serde::Deserialize, Debug)]
@ -113,12 +117,36 @@ struct Vocable {
sounds_like: Vec<String>,
}
/// Diarization mode in its serializable form, used as the `diarization`
/// member of `TranscriptionConfig`.
///
/// `rename_all = "lowercase"` makes the variants serialize as `"none"` and
/// `"speaker"`.
#[derive(serde::Serialize, Debug)]
#[serde(rename_all = "lowercase")]
enum Diarization {
/// No diarization.
None,
/// Speaker diarization.
Speaker,
}
impl From<SpeechmaticsTranscriberDiarization> for Diarization {
fn from(val: SpeechmaticsTranscriberDiarization) -> Self {
use SpeechmaticsTranscriberDiarization::*;
match val {
None => Diarization::None,
Speaker => Diarization::Speaker,
}
}
}
/// Speaker diarization options, used as the `speaker_diarization_config`
/// member of `TranscriptionConfig`.
#[derive(serde::Serialize, Debug)]
struct SpeakerDiarizationConfig {
/// Maximum number of speakers; filled from the `max-speakers` setting.
max_speakers: u32,
}
/// Transcription parameters in serializable form.
///
/// NOTE(review): presumably embedded in the start message sent to the
/// service — confirm against the code that builds `start_message`.
#[derive(serde::Serialize, Debug)]
struct TranscriptionConfig {
language: String,
enable_partials: bool,
max_delay: f32,
additional_vocab: Vec<Vocable>,
/// Diarization mode, converted from the `diarization` setting.
diarization: Diarization,
/// Not optional, so it is serialized whatever the value of `diarization`.
speaker_diarization_config: SpeakerDiarizationConfig,
}
#[derive(serde::Serialize, Debug)]
@ -162,6 +190,9 @@ const DEFAULT_MAX_DELAY_MS: u32 = 0;
const DEFAULT_LATENESS_MS: u32 = 0;
const DEFAULT_JOIN_PUNCTUATION: bool = true;
const DEFAULT_ENABLE_LATE_PUNCTUATION_HACK: bool = true;
// Default of the `diarization` property: no diarization.
const DEFAULT_DIARIZATION: SpeechmaticsTranscriberDiarization =
SpeechmaticsTranscriberDiarization::None;
// Default of the `max-speakers` property.
const DEFAULT_MAX_SPEAKERS: u32 = 50;
const GRANULARITY_MS: u32 = 100;
#[derive(Debug, Clone)]
@ -174,6 +205,8 @@ struct Settings {
api_key: Option<String>,
join_punctuation: bool,
enable_late_punctuation_hack: bool,
diarization: SpeechmaticsTranscriberDiarization,
max_speakers: u32,
}
impl Default for Settings {
@ -187,6 +220,8 @@ impl Default for Settings {
api_key: None,
join_punctuation: DEFAULT_JOIN_PUNCTUATION,
enable_late_punctuation_hack: DEFAULT_ENABLE_LATE_PUNCTUATION_HACK,
diarization: DEFAULT_DIARIZATION,
max_speakers: DEFAULT_MAX_SPEAKERS,
}
}
}
@ -196,6 +231,7 @@ struct ItemAccumulator {
text: String,
start_time: gst::ClockTime,
end_time: gst::ClockTime,
speaker: Option<String>,
}
impl From<ItemAccumulator> for gst::Buffer {
@ -290,6 +326,8 @@ impl TranscriberSrcPad {
let (latency, now, mut last_position, send_eos, seqnum) = {
let mut state = self.state.lock().unwrap();
let current_speaker = state.current_speaker.clone();
if let Some(ref mut accumulator_inner) = state.accumulator {
if now.saturating_sub(accumulator_inner.start_time + start_time) + granularity
> latency
@ -303,37 +341,46 @@ impl TranscriberSrcPad {
accumulator_inner.start_time,
accumulator_inner.end_time
);
if current_speaker != accumulator_inner.speaker {
let new_speaker = accumulator_inner.speaker.clone();
state.push_speaker(new_speaker);
}
let buf = state.accumulator.take().unwrap().into();
state.push_buffer(buf);
}
}
let send_eos =
state.send_eos && state.buffers.is_empty() && state.accumulator.is_none();
state.send_eos && state.output_items.is_empty() && state.accumulator.is_none();
while let Some(mut buf) = state.buffers.pop_front() {
{
let buf_mut = buf.make_mut();
let mut pts = buf_mut.pts().unwrap() + start_time;
let mut duration = buf_mut.duration().unwrap();
if let Some(position) = state.out_segment.position() {
if pts < position {
gst::debug!(
CAT,
imp = self,
"Adjusting item timing({:?} < {:?})",
pts,
position,
);
duration = duration.saturating_sub(position - pts);
pts = position;
while let Some(item) = state.output_items.pop_front() {
match item {
TranscriberOutput::Buffer(mut buf) => {
let buf_mut = buf.make_mut();
let mut pts = buf_mut.pts().unwrap() + start_time;
let mut duration = buf_mut.duration().unwrap();
if let Some(position) = state.out_segment.position() {
if pts < position {
gst::debug!(
CAT,
imp = self,
"Adjusting item timing({:?} < {:?})",
pts,
position,
);
duration = duration.saturating_sub(position - pts);
pts = position;
}
}
}
buf_mut.set_pts(pts);
buf_mut.set_duration(duration);
buf_mut.set_pts(pts);
buf_mut.set_duration(duration);
items.push(TranscriberOutput::Buffer(buf));
}
_ => items.push(item),
}
items.push(buf);
}
(
@ -354,33 +401,40 @@ impl TranscriberSrcPad {
.push_event(gst::event::Eos::builder().seqnum(seqnum).build());
}
for buf in items.drain(..) {
let pts = buf.pts().unwrap();
for item in items.drain(..) {
match item {
TranscriberOutput::Buffer(buf) => {
let pts = buf.pts().unwrap();
if let Some(last_position) = last_position {
if pts > last_position {
let gap_event = gst::event::Gap::builder(last_position)
.duration(pts - last_position)
.seqnum(seqnum)
.build();
gst::log!(CAT, "Pushing gap: {} -> {}", last_position, pts);
if !self.obj().push_event(gap_event) {
if let Some(last_position) = last_position {
if pts > last_position {
let gap_event = gst::event::Gap::builder(last_position)
.duration(pts - last_position)
.seqnum(seqnum)
.build();
gst::log!(CAT, "Pushing gap: {} -> {}", last_position, pts);
if !self.obj().push_event(gap_event) {
return false;
}
}
}
let pts_end = if let Some(duration) = buf.duration() {
pts + duration
} else {
pts
};
last_position = Some(pts_end);
gst::debug!(CAT, imp = self, "Pushing buffer: {} -> {}", pts, pts_end,);
if self.obj().push(buf).is_err() {
return false;
}
}
}
let pts_end = if let Some(duration) = buf.duration() {
pts + duration
} else {
pts
};
last_position = Some(pts_end);
gst::debug!(CAT, imp = self, "Pushing buffer: {} -> {}", pts, pts_end,);
if self.obj().push(buf).is_err() {
return false;
TranscriberOutput::Event(event) => {
let _ = self.obj().push_event(event);
}
}
}
@ -481,12 +535,18 @@ impl TranscriberSrcPad {
gst::debug!(
CAT,
imp = self,
"Item is ready: \"{}\", start_time: {}, end_time: {}",
"Item is ready: \"{}\", start_time: {}, end_time: {}, speaker: {:?}",
accumulator_inner.text,
accumulator_inner.start_time,
accumulator_inner.end_time
accumulator_inner.end_time,
accumulator_inner.speaker
);
if state.current_speaker != accumulator_inner.speaker {
let new_speaker = accumulator_inner.speaker.clone();
state.push_speaker(new_speaker);
}
let buffer = state.accumulator.take().unwrap().into();
state.push_buffer(buffer);
@ -494,6 +554,7 @@ impl TranscriberSrcPad {
text: alternative.content.clone(),
start_time,
end_time,
speaker: alternative.speaker.clone(),
});
}
} else if join_punctuation {
@ -501,6 +562,7 @@ impl TranscriberSrcPad {
text: alternative.content.clone(),
start_time,
end_time,
speaker: alternative.speaker.clone(),
});
} else {
let text = alternative.content.clone();
@ -508,12 +570,17 @@ impl TranscriberSrcPad {
gst::debug!(
CAT,
imp = self,
"Item is ready: \"{}\", start_time: {}, end_time: {}",
"Item is ready: \"{}\", start_time: {}, end_time: {}, speaker: {:?}",
text,
start_time,
end_time
end_time,
alternative.speaker,
);
if state.current_speaker != alternative.speaker {
state.push_speaker(alternative.speaker.clone());
}
let mut buf = gst::Buffer::from_slice(text.into_bytes());
{
@ -1295,6 +1362,10 @@ impl Transcriber {
enable_partials: false,
max_delay,
additional_vocab: state.additional_vocabulary.clone(),
diarization: settings.diarization.into(),
speaker_diarization_config: SpeakerDiarizationConfig {
max_speakers: settings.max_speakers,
},
},
translation_config: TranslationConfig {
target_languages: translation_languages,
@ -1304,7 +1375,7 @@ impl Transcriber {
let message = serde_json::to_string(&start_message).unwrap();
gst::trace!(CAT, imp = self, "Sending start message: {}", message);
gst::debug!(CAT, imp = self, "Sending start message: {}", message);
RUNTIME
.block_on(ws_sink.send(Message::Text(message)))
@ -1619,6 +1690,16 @@ impl ObjectImpl for Transcriber {
.default_value(DEFAULT_ENABLE_LATE_PUNCTUATION_HACK)
.mutable_ready()
.build(),
glib::ParamSpecEnum::builder_with_default("diarization", DEFAULT_DIARIZATION)
.nick("Diarization")
.blurb("Defines how to separate speakers in the audio")
.mutable_ready()
.build(),
glib::ParamSpecUInt::builder("max-speakers")
.nick("Max Speakers")
.blurb("The maximum number of speakers that may be detected with diarization=speaker")
.default_value(DEFAULT_MAX_SPEAKERS)
.build(),
]
});
@ -1746,6 +1827,16 @@ impl ObjectImpl for Transcriber {
let mut settings = self.settings.lock().unwrap();
settings.enable_late_punctuation_hack = value.get().expect("type checked upstream");
}
"diarization" => {
let mut settings = self.settings.lock().unwrap();
settings.diarization = value
.get::<SpeechmaticsTranscriberDiarization>()
.expect("type checked upstream");
}
"max-speakers" => {
let mut settings = self.settings.lock().unwrap();
settings.max_speakers = value.get().expect("type checked upstream");
}
_ => unimplemented!(),
}
}
@ -1801,6 +1892,8 @@ impl ObjectImpl for Transcriber {
let settings = self.settings.lock().unwrap();
settings.enable_late_punctuation_hack.to_value()
}
"diarization" => self.settings.lock().unwrap().diarization.to_value(),
"max-speakers" => self.settings.lock().unwrap().max_speakers.to_value(),
_ => unimplemented!(),
}
}
@ -1999,16 +2092,23 @@ struct TranscriberSrcPadSettings {
language_code: Option<String>,
}
/// An item queued for output on the source pad: either a buffer or an event.
#[derive(Debug)]
enum TranscriberOutput {
/// A buffer to push downstream.
Buffer(gst::Buffer),
/// An event to push downstream.
Event(gst::Event),
}
/// Mutable state of a transcriber source pad.
#[derive(Debug)]
struct TranscriberSrcPadState {
sender: Option<mpsc::Sender<Message>>,
accumulator: Option<ItemAccumulator>,
buffers: VecDeque<gst::Buffer>,
// Buffers and events waiting to be pushed; appended at the back and
// popped from the front.
output_items: VecDeque<TranscriberOutput>,
discont: bool,
send_eos: bool,
out_segment: gst::FormattedSegment<gst::ClockTime>,
seqnum: gst::Seqnum,
unsynced_pad: Option<gst::Pad>,
// Speaker most recently passed to `push_speaker`; `None` initially.
current_speaker: Option<String>,
}
impl Default for TranscriberSrcPadState {
@ -2016,12 +2116,13 @@ impl Default for TranscriberSrcPadState {
Self {
sender: None,
accumulator: None,
buffers: VecDeque::new(),
output_items: VecDeque::new(),
discont: true,
send_eos: false,
out_segment: gst::FormattedSegment::new(),
seqnum: gst::Seqnum::next(),
unsynced_pad: None,
current_speaker: None,
}
}
}
@ -2040,7 +2141,24 @@ impl TranscriberSrcPadState {
self.discont = false;
}
self.buffers.push_back(buf);
self.output_items.push_back(TranscriberOutput::Buffer(buf));
}
/// Appends `event` to the back of the output queue.
fn push_event(&mut self, event: gst::Event) {
    let queued = TranscriberOutput::Event(event);
    self.output_items.push_back(queued);
}
/// Queues a `rstranscribe/speaker-change` custom downstream event carrying
/// `speaker`, and records `speaker` as the current speaker.
fn push_speaker(&mut self, speaker: Option<String>) {
    let payload = gst::Structure::builder("rstranscribe/speaker-change")
        .field("speaker", &speaker)
        .build();
    let speaker_change = gst::event::CustomDownstream::builder(payload).build();

    // The structure only borrowed `speaker`, so it can be moved into the
    // state now that the event is built.
    self.current_speaker = speaker;
    self.push_event(speaker_change);
}
}

View file

@ -19,10 +19,23 @@ glib::wrapper! {
pub struct TranscriberSrcPad(ObjectSubclass<imp::TranscriberSrcPad>) @extends gst::Pad, gst::Object;
}
/// Diarization modes selectable through the `diarization` property.
///
/// Registered with GLib under the type name
/// `GstSpeechmaticsTranscriberDiarization`.
#[derive(Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Clone, Copy, glib::Enum)]
#[repr(u32)]
#[enum_type(name = "GstSpeechmaticsTranscriberDiarization")]
#[non_exhaustive]
pub enum SpeechmaticsTranscriberDiarization {
/// No diarization (nick `none`).
#[enum_value(name = "None: no diarization", nick = "none")]
None = 0,
/// Identify speakers by their voices (nick `speaker`).
#[enum_value(name = "Speaker: identify speakers by their voices", nick = "speaker")]
Speaker = 1,
}
pub fn register(plugin: &gst::Plugin) -> Result<(), glib::BoolError> {
#[cfg(feature = "doc")]
{
TranscriberSrcPad::static_type().mark_as_plugin_api(gst::PluginAPIFlags::empty());
SpeechmaticsTranscriberDiarization::static_type()
.mark_as_plugin_api(gst::PluginAPIFlags::empty());
}
gst::Element::register(
Some(plugin),

View file

@ -14099,6 +14099,18 @@
"type": "gchararray",
"writable": true
},
"diarization": {
"blurb": "Defines how to separate speakers in the audio",
"conditionally-available": false,
"construct": false,
"construct-only": false,
"controllable": false,
"default": "none (0)",
"mutable": "ready",
"readable": true,
"type": "GstSpeechmaticsTranscriberDiarization",
"writable": true
},
"enable-late-punctuation-hack": {
"blurb": "Pass a reduced max-delay to speechmatics to make sure we always get punctuation in time for joining it with the preceding word.",
"conditionally-available": false,
@ -14163,6 +14175,34 @@
"type": "guint",
"writable": true
},
"max-delay": {
"blurb": "Max delay to pass to the speechmatics API (0 = use latency)",
"conditionally-available": false,
"construct": false,
"construct-only": false,
"controllable": false,
"default": "0",
"max": "-1",
"min": "0",
"mutable": "null",
"readable": true,
"type": "guint",
"writable": true
},
"max-speakers": {
"blurb": "The maximum number of speakers that may be detected with diarization=speaker",
"conditionally-available": false,
"construct": false,
"construct-only": false,
"controllable": false,
"default": "50",
"max": "-1",
"min": "0",
"mutable": "null",
"readable": true,
"type": "guint",
"writable": true
},
"url": {
"blurb": "URL of the transcription server",
"conditionally-available": false,
@ -14182,6 +14222,21 @@
"filename": "gstspeechmatics",
"license": "Proprietary",
"other-types": {
"GstSpeechmaticsTranscriberDiarization": {
"kind": "enum",
"values": [
{
"desc": "None: no diarization",
"name": "none",
"value": "0"
},
{
"desc": "Speaker: identify speakers by their voices",
"name": "speaker",
"value": "1"
}
]
},
"GstSpeechmaticsTranscriberSrcPad": {
"hierarchy": [
"GstSpeechmaticsTranscriberSrcPad",