mirror of
https://gitlab.freedesktop.org/gstreamer/gst-plugins-rs.git
synced 2025-09-02 09:43:48 +00:00
awstranslate: unit test accumulator logic
Part-of: <https://gitlab.freedesktop.org/gstreamer/gst-plugins-rs/-/merge_requests/2221>
This commit is contained in:
parent
144aeb615a
commit
492ef81010
1 changed files with 240 additions and 86 deletions
|
@ -76,20 +76,85 @@ struct InputItem {
|
|||
pub struct InputItems(Vec<InputItem>);
|
||||
|
||||
impl InputItems {
|
||||
fn start_rtime(&self) -> gst::ClockTime {
|
||||
self.0.first().unwrap().rtime
|
||||
fn start_rtime(&self) -> Option<gst::ClockTime> {
|
||||
self.0.first().map(|item| item.rtime)
|
||||
}
|
||||
|
||||
fn start_pts(&self) -> gst::ClockTime {
|
||||
self.0.first().unwrap().pts
|
||||
fn start_pts(&self) -> Option<gst::ClockTime> {
|
||||
self.0.first().map(|item| item.pts)
|
||||
}
|
||||
|
||||
fn discont(&self) -> bool {
|
||||
self.0.first().unwrap().discont
|
||||
self.0.first().map(|item| item.discont).unwrap_or(false)
|
||||
}
|
||||
|
||||
fn end_pts(&self) -> gst::ClockTime {
|
||||
self.0.last().unwrap().end_pts
|
||||
fn end_pts(&self) -> Option<gst::ClockTime> {
|
||||
self.0.last().map(|item| item.end_pts)
|
||||
}
|
||||
|
||||
fn is_empty(&self) -> bool {
|
||||
self.0.is_empty()
|
||||
}
|
||||
|
||||
fn push(&mut self, item: InputItem) -> Result<(), Error> {
|
||||
if item.discont && !self.is_empty() {
|
||||
return Err(anyhow!("can't push discont item on non-empty accumulator"));
|
||||
}
|
||||
|
||||
self.0.push(item);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn drain(&mut self, up_to_punctuation: bool) -> Self {
|
||||
let items = match up_to_punctuation {
|
||||
true => {
|
||||
if let Some(punctuation_index) = self.0.iter().rposition(|item| item.is_punctuation)
|
||||
{
|
||||
let (items, trailing) = self.0.split_at(punctuation_index + 1);
|
||||
|
||||
let items = items.to_vec();
|
||||
|
||||
self.0 = trailing.to_vec();
|
||||
|
||||
gst::log!(CAT, "drained up to punctuation: {items:?}");
|
||||
|
||||
items
|
||||
} else {
|
||||
gst::log!(CAT, "drained all items: {:?}", self.0);
|
||||
|
||||
self.0.drain(..).collect()
|
||||
}
|
||||
}
|
||||
false => {
|
||||
gst::log!(CAT, "drained all items: {:?}", self.0);
|
||||
|
||||
self.0.drain(..).collect()
|
||||
}
|
||||
};
|
||||
|
||||
Self(items)
|
||||
}
|
||||
|
||||
fn timeout(&mut self, now: gst::ClockTime, upstream_min: gst::ClockTime) -> Option<Self> {
|
||||
if let Some(start_rtime) = self.start_rtime() {
|
||||
if start_rtime + upstream_min < now {
|
||||
gst::debug!(
|
||||
CAT,
|
||||
"draining on timeout: {start_rtime} + {upstream_min} < {now}",
|
||||
);
|
||||
Some(self.drain(true))
|
||||
} else {
|
||||
gst::trace!(
|
||||
CAT,
|
||||
"queued content is not late: {start_rtime} + {upstream_min} >= {now}"
|
||||
);
|
||||
None
|
||||
}
|
||||
} else {
|
||||
gst::trace!(CAT, "no queued content, cannot be late");
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -110,7 +175,7 @@ struct State {
|
|||
// (live, min, max)
|
||||
upstream_latency: Option<(bool, gst::ClockTime, Option<gst::ClockTime>)>,
|
||||
segment: Option<gst::FormattedSegment<gst::ClockTime>>,
|
||||
accumulator: Option<InputItems>,
|
||||
accumulator: InputItems,
|
||||
client: Option<aws_sdk_translate::Client>,
|
||||
send_abort_handle: Option<AbortHandle>,
|
||||
translate_tx: Option<mpsc::Sender<TranslateInput>>,
|
||||
|
@ -125,7 +190,7 @@ impl Default for State {
|
|||
Self {
|
||||
upstream_latency: None,
|
||||
segment: None,
|
||||
accumulator: None,
|
||||
accumulator: InputItems(vec![]),
|
||||
client: None,
|
||||
send_abort_handle: None,
|
||||
translate_tx: None,
|
||||
|
@ -254,7 +319,7 @@ impl Translate {
|
|||
|
||||
let (pts, duration) = gap.get();
|
||||
|
||||
if state.accumulator.is_none() {
|
||||
if state.accumulator.is_empty() {
|
||||
if let Some(translate_tx) = state.translate_tx.as_ref() {
|
||||
let _ = translate_tx.send(TranslateInput::Gap { pts, duration });
|
||||
}
|
||||
|
@ -266,8 +331,8 @@ impl Translate {
|
|||
let translate_tx = self.state.lock().unwrap().translate_tx.take();
|
||||
if let Some(translate_tx) = translate_tx {
|
||||
gst::debug!(CAT, imp = self, "received EOS, draining");
|
||||
|
||||
let _ = translate_tx.send(TranslateInput::Items(self.drain(false)));
|
||||
let items = self.state.lock().unwrap().accumulator.drain(false);
|
||||
let _ = translate_tx.send(TranslateInput::Items(items));
|
||||
}
|
||||
|
||||
true
|
||||
|
@ -292,7 +357,7 @@ impl Translate {
|
|||
};
|
||||
|
||||
if drain {
|
||||
let items = self.drain(false);
|
||||
let items = self.state.lock().unwrap().accumulator.drain(false);
|
||||
|
||||
if let Some(translate_tx) = self.state.lock().unwrap().translate_tx.as_ref() {
|
||||
let _ = translate_tx.send(TranslateInput::Items(items));
|
||||
|
@ -331,39 +396,22 @@ impl Translate {
|
|||
}
|
||||
|
||||
loop {
|
||||
let start_rtime = self
|
||||
let to_translate = self
|
||||
.state
|
||||
.lock()
|
||||
.unwrap()
|
||||
.accumulator
|
||||
.as_ref()
|
||||
.map(|accumulator| accumulator.start_rtime());
|
||||
|
||||
if let Some(start_rtime) = start_rtime {
|
||||
if start_rtime + upstream_min < now {
|
||||
gst::debug!(
|
||||
CAT,
|
||||
imp = self,
|
||||
"draining on timeout: {start_rtime} + {upstream_min} < {now}",
|
||||
);
|
||||
self.do_send(self.drain(true))?;
|
||||
} else {
|
||||
gst::trace!(
|
||||
CAT,
|
||||
imp = self,
|
||||
"queued content is not late: {start_rtime} + {upstream_min} >= {now}"
|
||||
);
|
||||
return Ok(gst::FlowSuccess::Ok);
|
||||
}
|
||||
.timeout(now, upstream_min);
|
||||
if let Some(to_translate) = to_translate {
|
||||
self.do_send(to_translate)?;
|
||||
} else {
|
||||
gst::trace!(CAT, imp = self, "no queued content, cannot be late");
|
||||
return Ok(gst::FlowSuccess::Ok);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn do_send(&self, to_translate: InputItems) -> Result<gst::FlowSuccess, gst::FlowError> {
|
||||
if to_translate.0.is_empty() {
|
||||
if to_translate.is_empty() {
|
||||
gst::trace!(CAT, imp = self, "nothing to send, returning early");
|
||||
return Ok(gst::FlowSuccess::Ok);
|
||||
}
|
||||
|
@ -402,43 +450,6 @@ impl Translate {
|
|||
}
|
||||
}
|
||||
|
||||
fn drain(&self, up_to_punctuation: bool) -> InputItems {
|
||||
let mut state = self.state.lock().unwrap();
|
||||
|
||||
let Some(accumulator) = state.accumulator.take() else {
|
||||
gst::trace!(CAT, imp = self, "accumulator is empty");
|
||||
return InputItems(vec![]);
|
||||
};
|
||||
|
||||
let items = match up_to_punctuation {
|
||||
true => {
|
||||
if let Some(punctuation_index) =
|
||||
accumulator.0.iter().rposition(|item| item.is_punctuation)
|
||||
{
|
||||
let (items, trailing) = accumulator.0.split_at(punctuation_index + 1);
|
||||
|
||||
if !trailing.is_empty() {
|
||||
state.accumulator = Some(InputItems(trailing.to_vec()));
|
||||
}
|
||||
|
||||
gst::log!(CAT, imp = self, "drained up to punctuation: {items:?}");
|
||||
|
||||
InputItems(items.to_vec())
|
||||
} else {
|
||||
gst::log!(CAT, imp = self, "drained all items: {accumulator:?}");
|
||||
accumulator
|
||||
}
|
||||
}
|
||||
false => {
|
||||
gst::log!(CAT, imp = self, "drained all items: {accumulator:?}");
|
||||
|
||||
accumulator
|
||||
}
|
||||
};
|
||||
|
||||
items
|
||||
}
|
||||
|
||||
async fn send(&self, to_translate: InputItems) -> Result<Vec<TranslateOutput>, Error> {
|
||||
let (input_lang, output_lang, latency, tokenization_method) = {
|
||||
let settings = self.settings.lock().unwrap();
|
||||
|
@ -457,7 +468,7 @@ impl Translate {
|
|||
.segment
|
||||
.as_mut()
|
||||
.unwrap()
|
||||
.set_position(Some(to_translate.end_pts()));
|
||||
.set_position(to_translate.end_pts());
|
||||
|
||||
(
|
||||
state.client.as_ref().unwrap().clone(),
|
||||
|
@ -746,7 +757,7 @@ impl Translate {
|
|||
})?;
|
||||
|
||||
let drained_items = if buffer.flags().contains(gst::BufferFlags::DISCONT) {
|
||||
let items = self.drain(false);
|
||||
let items = self.state.lock().unwrap().accumulator.drain(false);
|
||||
|
||||
gst::log!(CAT, imp = self, "draining on discont");
|
||||
|
||||
|
@ -800,18 +811,8 @@ impl Translate {
|
|||
discont,
|
||||
};
|
||||
|
||||
if let Some(accumulator) = state.accumulator.as_mut() {
|
||||
gst::log!(CAT, imp = self, "queuing item on accumulator: {item:?}");
|
||||
accumulator.0.push(item);
|
||||
} else {
|
||||
gst::log!(
|
||||
CAT,
|
||||
imp = self,
|
||||
"creating new accumulator with item: {item:?}"
|
||||
);
|
||||
state.accumulator = Some(InputItems(vec![item]));
|
||||
}
|
||||
|
||||
gst::log!(CAT, imp = self, "queuing item on accumulator: {item:?}");
|
||||
state.accumulator.push(item).unwrap();
|
||||
state.chained_one = true;
|
||||
|
||||
gst::trace!(
|
||||
|
@ -1150,3 +1151,156 @@ impl ElementImpl for Translate {
|
|||
Some(gst::SystemClock::obtain())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{InputItem, InputItems};
|
||||
|
||||
#[test]
|
||||
fn accumulator_basic() {
|
||||
let mut accumulator = InputItems(vec![]);
|
||||
|
||||
assert!(accumulator.is_empty());
|
||||
assert_eq!(accumulator.start_rtime(), None);
|
||||
assert_eq!(accumulator.start_pts(), None);
|
||||
assert_eq!(accumulator.discont(), false);
|
||||
assert_eq!(accumulator.end_pts(), None);
|
||||
assert!(accumulator.drain(false).is_empty());
|
||||
|
||||
assert!(accumulator
|
||||
.push(InputItem {
|
||||
content: "0".into(),
|
||||
pts: gst::ClockTime::from_nseconds(0),
|
||||
rtime: gst::ClockTime::from_nseconds(0),
|
||||
end_pts: gst::ClockTime::from_nseconds(1),
|
||||
is_punctuation: false,
|
||||
discont: true
|
||||
})
|
||||
.is_ok());
|
||||
|
||||
assert!(accumulator
|
||||
.push(InputItem {
|
||||
content: "2".into(),
|
||||
pts: gst::ClockTime::from_nseconds(2),
|
||||
rtime: gst::ClockTime::from_nseconds(2),
|
||||
end_pts: gst::ClockTime::from_nseconds(3),
|
||||
is_punctuation: false,
|
||||
discont: false
|
||||
})
|
||||
.is_ok());
|
||||
|
||||
assert!(accumulator
|
||||
.push(InputItem {
|
||||
content: "10".into(),
|
||||
pts: gst::ClockTime::from_nseconds(10),
|
||||
rtime: gst::ClockTime::from_nseconds(20),
|
||||
end_pts: gst::ClockTime::from_nseconds(10),
|
||||
is_punctuation: false,
|
||||
discont: true
|
||||
})
|
||||
.is_err());
|
||||
|
||||
assert!(!accumulator.is_empty());
|
||||
assert_eq!(
|
||||
accumulator.start_rtime(),
|
||||
Some(gst::ClockTime::from_nseconds(0))
|
||||
);
|
||||
assert_eq!(
|
||||
accumulator.start_pts(),
|
||||
Some(gst::ClockTime::from_nseconds(0))
|
||||
);
|
||||
assert_eq!(accumulator.discont(), true);
|
||||
assert_eq!(
|
||||
accumulator.end_pts(),
|
||||
Some(gst::ClockTime::from_nseconds(3))
|
||||
);
|
||||
|
||||
assert!(!accumulator.drain(false).is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_accumulator_timeout() {
|
||||
let mut accumulator = InputItems(vec![
|
||||
InputItem {
|
||||
content: "0".into(),
|
||||
pts: gst::ClockTime::from_nseconds(0),
|
||||
rtime: gst::ClockTime::from_nseconds(0),
|
||||
end_pts: gst::ClockTime::from_nseconds(1),
|
||||
is_punctuation: false,
|
||||
discont: true,
|
||||
},
|
||||
InputItem {
|
||||
content: "2".into(),
|
||||
pts: gst::ClockTime::from_nseconds(2),
|
||||
rtime: gst::ClockTime::from_nseconds(2),
|
||||
end_pts: gst::ClockTime::from_nseconds(3),
|
||||
is_punctuation: false,
|
||||
discont: false,
|
||||
},
|
||||
]);
|
||||
|
||||
let upstream_min = gst::ClockTime::from_nseconds(5);
|
||||
|
||||
assert!(accumulator
|
||||
.timeout(gst::ClockTime::from_nseconds(5), upstream_min)
|
||||
.is_none());
|
||||
|
||||
assert_eq!(
|
||||
accumulator
|
||||
.timeout(gst::ClockTime::from_nseconds(6), upstream_min)
|
||||
.unwrap()
|
||||
.0
|
||||
.len(),
|
||||
2
|
||||
);
|
||||
|
||||
assert!(accumulator.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_accumulator_timeout_punctuation() {
|
||||
let mut accumulator = InputItems(vec![
|
||||
InputItem {
|
||||
content: "0".into(),
|
||||
pts: gst::ClockTime::from_nseconds(0),
|
||||
rtime: gst::ClockTime::from_nseconds(0),
|
||||
end_pts: gst::ClockTime::from_nseconds(1),
|
||||
is_punctuation: false,
|
||||
discont: true,
|
||||
},
|
||||
InputItem {
|
||||
content: ".".into(),
|
||||
pts: gst::ClockTime::from_nseconds(2),
|
||||
rtime: gst::ClockTime::from_nseconds(2),
|
||||
end_pts: gst::ClockTime::from_nseconds(3),
|
||||
is_punctuation: true,
|
||||
discont: false,
|
||||
},
|
||||
InputItem {
|
||||
content: "5".into(),
|
||||
pts: gst::ClockTime::from_nseconds(5),
|
||||
rtime: gst::ClockTime::from_nseconds(5),
|
||||
end_pts: gst::ClockTime::from_nseconds(6),
|
||||
is_punctuation: false,
|
||||
discont: false,
|
||||
},
|
||||
]);
|
||||
|
||||
let upstream_min = gst::ClockTime::from_nseconds(5);
|
||||
|
||||
assert!(accumulator
|
||||
.timeout(gst::ClockTime::from_nseconds(5), upstream_min)
|
||||
.is_none());
|
||||
|
||||
assert_eq!(
|
||||
accumulator
|
||||
.timeout(gst::ClockTime::from_nseconds(6), upstream_min)
|
||||
.unwrap()
|
||||
.0
|
||||
.len(),
|
||||
2
|
||||
);
|
||||
|
||||
assert_eq!(accumulator.0.len(), 1);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue