awstranslate: unit test accumulator logic

Part-of: <https://gitlab.freedesktop.org/gstreamer/gst-plugins-rs/-/merge_requests/2221>
This commit is contained in:
Mathieu Duponchelle 2025-05-01 20:01:27 +02:00 committed by GStreamer Marge Bot
parent 144aeb615a
commit 492ef81010

View file

@ -76,20 +76,85 @@ struct InputItem {
pub struct InputItems(Vec<InputItem>);
impl InputItems {
fn start_rtime(&self) -> gst::ClockTime {
self.0.first().unwrap().rtime
fn start_rtime(&self) -> Option<gst::ClockTime> {
self.0.first().map(|item| item.rtime)
}
fn start_pts(&self) -> gst::ClockTime {
self.0.first().unwrap().pts
fn start_pts(&self) -> Option<gst::ClockTime> {
self.0.first().map(|item| item.pts)
}
fn discont(&self) -> bool {
self.0.first().unwrap().discont
self.0.first().map(|item| item.discont).unwrap_or(false)
}
fn end_pts(&self) -> gst::ClockTime {
self.0.last().unwrap().end_pts
fn end_pts(&self) -> Option<gst::ClockTime> {
self.0.last().map(|item| item.end_pts)
}
fn is_empty(&self) -> bool {
self.0.is_empty()
}
fn push(&mut self, item: InputItem) -> Result<(), Error> {
if item.discont && !self.is_empty() {
return Err(anyhow!("can't push discont item on non-empty accumulator"));
}
self.0.push(item);
Ok(())
}
fn drain(&mut self, up_to_punctuation: bool) -> Self {
let items = match up_to_punctuation {
true => {
if let Some(punctuation_index) = self.0.iter().rposition(|item| item.is_punctuation)
{
let (items, trailing) = self.0.split_at(punctuation_index + 1);
let items = items.to_vec();
self.0 = trailing.to_vec();
gst::log!(CAT, "drained up to punctuation: {items:?}");
items
} else {
gst::log!(CAT, "drained all items: {:?}", self.0);
self.0.drain(..).collect()
}
}
false => {
gst::log!(CAT, "drained all items: {:?}", self.0);
self.0.drain(..).collect()
}
};
Self(items)
}
fn timeout(&mut self, now: gst::ClockTime, upstream_min: gst::ClockTime) -> Option<Self> {
if let Some(start_rtime) = self.start_rtime() {
if start_rtime + upstream_min < now {
gst::debug!(
CAT,
"draining on timeout: {start_rtime} + {upstream_min} < {now}",
);
Some(self.drain(true))
} else {
gst::trace!(
CAT,
"queued content is not late: {start_rtime} + {upstream_min} >= {now}"
);
None
}
} else {
gst::trace!(CAT, "no queued content, cannot be late");
None
}
}
}
@ -110,7 +175,7 @@ struct State {
// (live, min, max)
upstream_latency: Option<(bool, gst::ClockTime, Option<gst::ClockTime>)>,
segment: Option<gst::FormattedSegment<gst::ClockTime>>,
accumulator: Option<InputItems>,
accumulator: InputItems,
client: Option<aws_sdk_translate::Client>,
send_abort_handle: Option<AbortHandle>,
translate_tx: Option<mpsc::Sender<TranslateInput>>,
@ -125,7 +190,7 @@ impl Default for State {
Self {
upstream_latency: None,
segment: None,
accumulator: None,
accumulator: InputItems(vec![]),
client: None,
send_abort_handle: None,
translate_tx: None,
@ -254,7 +319,7 @@ impl Translate {
let (pts, duration) = gap.get();
if state.accumulator.is_none() {
if state.accumulator.is_empty() {
if let Some(translate_tx) = state.translate_tx.as_ref() {
let _ = translate_tx.send(TranslateInput::Gap { pts, duration });
}
@ -266,8 +331,8 @@ impl Translate {
let translate_tx = self.state.lock().unwrap().translate_tx.take();
if let Some(translate_tx) = translate_tx {
gst::debug!(CAT, imp = self, "received EOS, draining");
let _ = translate_tx.send(TranslateInput::Items(self.drain(false)));
let items = self.state.lock().unwrap().accumulator.drain(false);
let _ = translate_tx.send(TranslateInput::Items(items));
}
true
@ -292,7 +357,7 @@ impl Translate {
};
if drain {
let items = self.drain(false);
let items = self.state.lock().unwrap().accumulator.drain(false);
if let Some(translate_tx) = self.state.lock().unwrap().translate_tx.as_ref() {
let _ = translate_tx.send(TranslateInput::Items(items));
@ -331,39 +396,22 @@ impl Translate {
}
loop {
let start_rtime = self
let to_translate = self
.state
.lock()
.unwrap()
.accumulator
.as_ref()
.map(|accumulator| accumulator.start_rtime());
if let Some(start_rtime) = start_rtime {
if start_rtime + upstream_min < now {
gst::debug!(
CAT,
imp = self,
"draining on timeout: {start_rtime} + {upstream_min} < {now}",
);
self.do_send(self.drain(true))?;
.timeout(now, upstream_min);
if let Some(to_translate) = to_translate {
self.do_send(to_translate)?;
} else {
gst::trace!(
CAT,
imp = self,
"queued content is not late: {start_rtime} + {upstream_min} >= {now}"
);
return Ok(gst::FlowSuccess::Ok);
}
} else {
gst::trace!(CAT, imp = self, "no queued content, cannot be late");
return Ok(gst::FlowSuccess::Ok);
}
}
}
fn do_send(&self, to_translate: InputItems) -> Result<gst::FlowSuccess, gst::FlowError> {
if to_translate.0.is_empty() {
if to_translate.is_empty() {
gst::trace!(CAT, imp = self, "nothing to send, returning early");
return Ok(gst::FlowSuccess::Ok);
}
@ -402,43 +450,6 @@ impl Translate {
}
}
fn drain(&self, up_to_punctuation: bool) -> InputItems {
let mut state = self.state.lock().unwrap();
let Some(accumulator) = state.accumulator.take() else {
gst::trace!(CAT, imp = self, "accumulator is empty");
return InputItems(vec![]);
};
let items = match up_to_punctuation {
true => {
if let Some(punctuation_index) =
accumulator.0.iter().rposition(|item| item.is_punctuation)
{
let (items, trailing) = accumulator.0.split_at(punctuation_index + 1);
if !trailing.is_empty() {
state.accumulator = Some(InputItems(trailing.to_vec()));
}
gst::log!(CAT, imp = self, "drained up to punctuation: {items:?}");
InputItems(items.to_vec())
} else {
gst::log!(CAT, imp = self, "drained all items: {accumulator:?}");
accumulator
}
}
false => {
gst::log!(CAT, imp = self, "drained all items: {accumulator:?}");
accumulator
}
};
items
}
async fn send(&self, to_translate: InputItems) -> Result<Vec<TranslateOutput>, Error> {
let (input_lang, output_lang, latency, tokenization_method) = {
let settings = self.settings.lock().unwrap();
@ -457,7 +468,7 @@ impl Translate {
.segment
.as_mut()
.unwrap()
.set_position(Some(to_translate.end_pts()));
.set_position(to_translate.end_pts());
(
state.client.as_ref().unwrap().clone(),
@ -746,7 +757,7 @@ impl Translate {
})?;
let drained_items = if buffer.flags().contains(gst::BufferFlags::DISCONT) {
let items = self.drain(false);
let items = self.state.lock().unwrap().accumulator.drain(false);
gst::log!(CAT, imp = self, "draining on discont");
@ -800,18 +811,8 @@ impl Translate {
discont,
};
if let Some(accumulator) = state.accumulator.as_mut() {
gst::log!(CAT, imp = self, "queuing item on accumulator: {item:?}");
accumulator.0.push(item);
} else {
gst::log!(
CAT,
imp = self,
"creating new accumulator with item: {item:?}"
);
state.accumulator = Some(InputItems(vec![item]));
}
state.accumulator.push(item).unwrap();
state.chained_one = true;
gst::trace!(
@ -1150,3 +1151,156 @@ impl ElementImpl for Translate {
Some(gst::SystemClock::obtain())
}
}
#[cfg(test)]
mod tests {
use super::{InputItem, InputItems};
#[test]
fn accumulator_basic() {
let mut accumulator = InputItems(vec![]);
assert!(accumulator.is_empty());
assert_eq!(accumulator.start_rtime(), None);
assert_eq!(accumulator.start_pts(), None);
assert_eq!(accumulator.discont(), false);
assert_eq!(accumulator.end_pts(), None);
assert!(accumulator.drain(false).is_empty());
assert!(accumulator
.push(InputItem {
content: "0".into(),
pts: gst::ClockTime::from_nseconds(0),
rtime: gst::ClockTime::from_nseconds(0),
end_pts: gst::ClockTime::from_nseconds(1),
is_punctuation: false,
discont: true
})
.is_ok());
assert!(accumulator
.push(InputItem {
content: "2".into(),
pts: gst::ClockTime::from_nseconds(2),
rtime: gst::ClockTime::from_nseconds(2),
end_pts: gst::ClockTime::from_nseconds(3),
is_punctuation: false,
discont: false
})
.is_ok());
assert!(accumulator
.push(InputItem {
content: "10".into(),
pts: gst::ClockTime::from_nseconds(10),
rtime: gst::ClockTime::from_nseconds(20),
end_pts: gst::ClockTime::from_nseconds(10),
is_punctuation: false,
discont: true
})
.is_err());
assert!(!accumulator.is_empty());
assert_eq!(
accumulator.start_rtime(),
Some(gst::ClockTime::from_nseconds(0))
);
assert_eq!(
accumulator.start_pts(),
Some(gst::ClockTime::from_nseconds(0))
);
assert_eq!(accumulator.discont(), true);
assert_eq!(
accumulator.end_pts(),
Some(gst::ClockTime::from_nseconds(3))
);
assert!(!accumulator.drain(false).is_empty());
}
#[test]
fn test_accumulator_timeout() {
let mut accumulator = InputItems(vec![
InputItem {
content: "0".into(),
pts: gst::ClockTime::from_nseconds(0),
rtime: gst::ClockTime::from_nseconds(0),
end_pts: gst::ClockTime::from_nseconds(1),
is_punctuation: false,
discont: true,
},
InputItem {
content: "2".into(),
pts: gst::ClockTime::from_nseconds(2),
rtime: gst::ClockTime::from_nseconds(2),
end_pts: gst::ClockTime::from_nseconds(3),
is_punctuation: false,
discont: false,
},
]);
let upstream_min = gst::ClockTime::from_nseconds(5);
assert!(accumulator
.timeout(gst::ClockTime::from_nseconds(5), upstream_min)
.is_none());
assert_eq!(
accumulator
.timeout(gst::ClockTime::from_nseconds(6), upstream_min)
.unwrap()
.0
.len(),
2
);
assert!(accumulator.is_empty());
}
#[test]
fn test_accumulator_timeout_punctuation() {
let mut accumulator = InputItems(vec![
InputItem {
content: "0".into(),
pts: gst::ClockTime::from_nseconds(0),
rtime: gst::ClockTime::from_nseconds(0),
end_pts: gst::ClockTime::from_nseconds(1),
is_punctuation: false,
discont: true,
},
InputItem {
content: ".".into(),
pts: gst::ClockTime::from_nseconds(2),
rtime: gst::ClockTime::from_nseconds(2),
end_pts: gst::ClockTime::from_nseconds(3),
is_punctuation: true,
discont: false,
},
InputItem {
content: "5".into(),
pts: gst::ClockTime::from_nseconds(5),
rtime: gst::ClockTime::from_nseconds(5),
end_pts: gst::ClockTime::from_nseconds(6),
is_punctuation: false,
discont: false,
},
]);
let upstream_min = gst::ClockTime::from_nseconds(5);
assert!(accumulator
.timeout(gst::ClockTime::from_nseconds(5), upstream_min)
.is_none());
assert_eq!(
accumulator
.timeout(gst::ClockTime::from_nseconds(6), upstream_min)
.unwrap()
.0
.len(),
2
);
assert_eq!(accumulator.0.len(), 1);
}
}