From 492ef8101039f168f438d01205d20e5aa464b6d7 Mon Sep 17 00:00:00 2001 From: Mathieu Duponchelle Date: Thu, 1 May 2025 20:01:27 +0200 Subject: [PATCH] awstranslate: unit test accumulator logic Part-of: --- net/aws/src/translate/imp.rs | 326 ++++++++++++++++++++++++++--------- 1 file changed, 240 insertions(+), 86 deletions(-) diff --git a/net/aws/src/translate/imp.rs b/net/aws/src/translate/imp.rs index bf40e790a..245ebad8e 100644 --- a/net/aws/src/translate/imp.rs +++ b/net/aws/src/translate/imp.rs @@ -76,20 +76,85 @@ struct InputItem { pub struct InputItems(Vec); impl InputItems { - fn start_rtime(&self) -> gst::ClockTime { - self.0.first().unwrap().rtime + fn start_rtime(&self) -> Option { + self.0.first().map(|item| item.rtime) } - fn start_pts(&self) -> gst::ClockTime { - self.0.first().unwrap().pts + fn start_pts(&self) -> Option { + self.0.first().map(|item| item.pts) } fn discont(&self) -> bool { - self.0.first().unwrap().discont + self.0.first().map(|item| item.discont).unwrap_or(false) } - fn end_pts(&self) -> gst::ClockTime { - self.0.last().unwrap().end_pts + fn end_pts(&self) -> Option { + self.0.last().map(|item| item.end_pts) + } + + fn is_empty(&self) -> bool { + self.0.is_empty() + } + + fn push(&mut self, item: InputItem) -> Result<(), Error> { + if item.discont && !self.is_empty() { + return Err(anyhow!("can't push discont item on non-empty accumulator")); + } + + self.0.push(item); + + Ok(()) + } + + fn drain(&mut self, up_to_punctuation: bool) -> Self { + let items = match up_to_punctuation { + true => { + if let Some(punctuation_index) = self.0.iter().rposition(|item| item.is_punctuation) + { + let (items, trailing) = self.0.split_at(punctuation_index + 1); + + let items = items.to_vec(); + + self.0 = trailing.to_vec(); + + gst::log!(CAT, "drained up to punctuation: {items:?}"); + + items + } else { + gst::log!(CAT, "drained all items: {:?}", self.0); + + self.0.drain(..).collect() + } + } + false => { + gst::log!(CAT, "drained all items: {:?}", self.0); + + self.0.drain(..).collect() + } + }; + + Self(items) + } + + fn timeout(&mut self, now: gst::ClockTime, upstream_min: gst::ClockTime) -> Option { + if let Some(start_rtime) = self.start_rtime() { + if start_rtime + upstream_min < now { + gst::debug!( + CAT, + "draining on timeout: {start_rtime} + {upstream_min} < {now}", + ); + Some(self.drain(true)) + } else { + gst::trace!( + CAT, + "queued content is not late: {start_rtime} + {upstream_min} >= {now}" + ); + None + } + } else { + gst::trace!(CAT, "no queued content, cannot be late"); + None + } } } @@ -110,7 +175,7 @@ struct State { // (live, min, max) upstream_latency: Option<(bool, gst::ClockTime, Option)>, segment: Option>, - accumulator: Option, + accumulator: InputItems, client: Option, send_abort_handle: Option, translate_tx: Option>, @@ -125,7 +190,7 @@ impl Default for State { Self { upstream_latency: None, segment: None, - accumulator: None, + accumulator: InputItems(vec![]), client: None, send_abort_handle: None, translate_tx: None, @@ -254,7 +319,7 @@ impl Translate { let (pts, duration) = gap.get(); - if state.accumulator.is_none() { + if state.accumulator.is_empty() { if let Some(translate_tx) = state.translate_tx.as_ref() { let _ = translate_tx.send(TranslateInput::Gap { pts, duration }); } @@ -266,8 +331,8 @@ impl Translate { let translate_tx = self.state.lock().unwrap().translate_tx.take(); if let Some(translate_tx) = translate_tx { gst::debug!(CAT, imp = self, "received EOS, draining"); - - let _ = translate_tx.send(TranslateInput::Items(self.drain(false))); + let items = self.state.lock().unwrap().accumulator.drain(false); + let _ = translate_tx.send(TranslateInput::Items(items)); } true @@ -292,7 +357,7 @@ impl Translate { }; if drain { - let items = self.drain(false); + let items = self.state.lock().unwrap().accumulator.drain(false); if let Some(translate_tx) = self.state.lock().unwrap().translate_tx.as_ref() { let _ = translate_tx.send(TranslateInput::Items(items)); @@ -331,39 +396,22 @@ impl Translate { } loop { - let start_rtime = self + let to_translate = self .state .lock() .unwrap() .accumulator - .as_ref() - .map(|accumulator| accumulator.start_rtime()); - - if let Some(start_rtime) = start_rtime { - if start_rtime + upstream_min < now { - gst::debug!( - CAT, - imp = self, - "draining on timeout: {start_rtime} + {upstream_min} < {now}", - ); - self.do_send(self.drain(true))?; - } else { - gst::trace!( - CAT, - imp = self, - "queued content is not late: {start_rtime} + {upstream_min} >= {now}" - ); - return Ok(gst::FlowSuccess::Ok); - } + .timeout(now, upstream_min); + if let Some(to_translate) = to_translate { + self.do_send(to_translate)?; } else { - gst::trace!(CAT, imp = self, "no queued content, cannot be late"); return Ok(gst::FlowSuccess::Ok); } } } fn do_send(&self, to_translate: InputItems) -> Result { - if to_translate.0.is_empty() { + if to_translate.is_empty() { gst::trace!(CAT, imp = self, "nothing to send, returning early"); return Ok(gst::FlowSuccess::Ok); } @@ -402,43 +450,6 @@ impl Translate { } } - fn drain(&self, up_to_punctuation: bool) -> InputItems { - let mut state = self.state.lock().unwrap(); - - let Some(accumulator) = state.accumulator.take() else { - gst::trace!(CAT, imp = self, "accumulator is empty"); - return InputItems(vec![]); - }; - - let items = match up_to_punctuation { - true => { - if let Some(punctuation_index) = - accumulator.0.iter().rposition(|item| item.is_punctuation) - { - let (items, trailing) = accumulator.0.split_at(punctuation_index + 1); - - if !trailing.is_empty() { - state.accumulator = Some(InputItems(trailing.to_vec())); - } - - gst::log!(CAT, imp = self, "drained up to punctuation: {items:?}"); - - InputItems(items.to_vec()) - } else { - gst::log!(CAT, imp = self, "drained all items: {accumulator:?}"); - accumulator - } - } - false => { - gst::log!(CAT, imp = self, "drained all items: {accumulator:?}"); - - accumulator - } - }; - - items - } - async fn send(&self, to_translate: InputItems) -> Result, Error> { let (input_lang, output_lang, latency, tokenization_method) = { let settings = self.settings.lock().unwrap(); @@ -457,7 +468,7 @@ impl Translate { .segment .as_mut() .unwrap() - .set_position(Some(to_translate.end_pts())); + .set_position(to_translate.end_pts()); ( state.client.as_ref().unwrap().clone(), @@ -746,7 +757,7 @@ impl Translate { })?; let drained_items = if buffer.flags().contains(gst::BufferFlags::DISCONT) { - let items = self.drain(false); + let items = self.state.lock().unwrap().accumulator.drain(false); gst::log!(CAT, imp = self, "draining on discont"); @@ -800,18 +811,8 @@ impl Translate { discont, }; - if let Some(accumulator) = state.accumulator.as_mut() { - gst::log!(CAT, imp = self, "queuing item on accumulator: {item:?}"); - accumulator.0.push(item); - } else { - gst::log!( - CAT, - imp = self, - "creating new accumulator with item: {item:?}" - ); - state.accumulator = Some(InputItems(vec![item])); - } - + gst::log!(CAT, imp = self, "queuing item on accumulator: {item:?}"); + state.accumulator.push(item).unwrap(); state.chained_one = true; gst::trace!( @@ -1150,3 +1151,156 @@ impl ElementImpl for Translate { Some(gst::SystemClock::obtain()) } } + +#[cfg(test)] +mod tests { + use super::{InputItem, InputItems}; + + #[test] + fn accumulator_basic() { + let mut accumulator = InputItems(vec![]); + + assert!(accumulator.is_empty()); + assert_eq!(accumulator.start_rtime(), None); + assert_eq!(accumulator.start_pts(), None); + assert_eq!(accumulator.discont(), false); + assert_eq!(accumulator.end_pts(), None); + assert!(accumulator.drain(false).is_empty()); + + assert!(accumulator + .push(InputItem { + content: "0".into(), + pts: gst::ClockTime::from_nseconds(0), + rtime: gst::ClockTime::from_nseconds(0), + end_pts: gst::ClockTime::from_nseconds(1), + is_punctuation: false, + discont: true + }) + .is_ok()); + + assert!(accumulator + .push(InputItem { + content: "2".into(), + pts: gst::ClockTime::from_nseconds(2), + rtime: gst::ClockTime::from_nseconds(2), + end_pts: gst::ClockTime::from_nseconds(3), + is_punctuation: false, + discont: false + }) + .is_ok()); + + assert!(accumulator + .push(InputItem { + content: "10".into(), + pts: gst::ClockTime::from_nseconds(10), + rtime: gst::ClockTime::from_nseconds(20), + end_pts: gst::ClockTime::from_nseconds(10), + is_punctuation: false, + discont: true + }) + .is_err()); + + assert!(!accumulator.is_empty()); + assert_eq!( + accumulator.start_rtime(), + Some(gst::ClockTime::from_nseconds(0)) + ); + assert_eq!( + accumulator.start_pts(), + Some(gst::ClockTime::from_nseconds(0)) + ); + assert_eq!(accumulator.discont(), true); + assert_eq!( + accumulator.end_pts(), + Some(gst::ClockTime::from_nseconds(3)) + ); + + assert!(!accumulator.drain(false).is_empty()); + } + + #[test] + fn test_accumulator_timeout() { + let mut accumulator = InputItems(vec![ + InputItem { + content: "0".into(), + pts: gst::ClockTime::from_nseconds(0), + rtime: gst::ClockTime::from_nseconds(0), + end_pts: gst::ClockTime::from_nseconds(1), + is_punctuation: false, + discont: true, + }, + InputItem { + content: "2".into(), + pts: gst::ClockTime::from_nseconds(2), + rtime: gst::ClockTime::from_nseconds(2), + end_pts: gst::ClockTime::from_nseconds(3), + is_punctuation: false, + discont: false, + }, + ]); + + let upstream_min = gst::ClockTime::from_nseconds(5); + + assert!(accumulator + .timeout(gst::ClockTime::from_nseconds(5), upstream_min) + .is_none()); + + assert_eq!( + accumulator + .timeout(gst::ClockTime::from_nseconds(6), upstream_min) + .unwrap() + .0 + .len(), + 2 + ); + + assert!(accumulator.is_empty()); + } + + #[test] + fn test_accumulator_timeout_punctuation() { + let mut accumulator = InputItems(vec![ + InputItem { + content: "0".into(), + pts: gst::ClockTime::from_nseconds(0), + rtime: gst::ClockTime::from_nseconds(0), + end_pts: gst::ClockTime::from_nseconds(1), + is_punctuation: false, + discont: true, + }, + InputItem { + content: ".".into(), + pts: gst::ClockTime::from_nseconds(2), + rtime: gst::ClockTime::from_nseconds(2), + end_pts: gst::ClockTime::from_nseconds(3), + is_punctuation: true, + discont: false, + }, + InputItem { + content: "5".into(), + pts: gst::ClockTime::from_nseconds(5), + rtime: gst::ClockTime::from_nseconds(5), + end_pts: gst::ClockTime::from_nseconds(6), + is_punctuation: false, + discont: false, + }, + ]); + + let upstream_min = gst::ClockTime::from_nseconds(5); + + assert!(accumulator + .timeout(gst::ClockTime::from_nseconds(5), upstream_min) + .is_none()); + + assert_eq!( + accumulator + .timeout(gst::ClockTime::from_nseconds(6), upstream_min) + .unwrap() + .0 + .len(), + 2 + ); + + assert_eq!(accumulator.0.len(), 1); + } +}