mirror of
https://gitlab.freedesktop.org/gstreamer/gst-plugins-rs.git
synced 2025-09-02 17:53:48 +00:00
awstranslate: unit test accumulator logic
Part-of: <https://gitlab.freedesktop.org/gstreamer/gst-plugins-rs/-/merge_requests/2221>
This commit is contained in:
parent
144aeb615a
commit
492ef81010
1 changed files with 240 additions and 86 deletions
|
@ -76,20 +76,85 @@ struct InputItem {
|
||||||
pub struct InputItems(Vec<InputItem>);
|
pub struct InputItems(Vec<InputItem>);
|
||||||
|
|
||||||
impl InputItems {
|
impl InputItems {
|
||||||
fn start_rtime(&self) -> gst::ClockTime {
|
fn start_rtime(&self) -> Option<gst::ClockTime> {
|
||||||
self.0.first().unwrap().rtime
|
self.0.first().map(|item| item.rtime)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn start_pts(&self) -> gst::ClockTime {
|
fn start_pts(&self) -> Option<gst::ClockTime> {
|
||||||
self.0.first().unwrap().pts
|
self.0.first().map(|item| item.pts)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn discont(&self) -> bool {
|
fn discont(&self) -> bool {
|
||||||
self.0.first().unwrap().discont
|
self.0.first().map(|item| item.discont).unwrap_or(false)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn end_pts(&self) -> gst::ClockTime {
|
fn end_pts(&self) -> Option<gst::ClockTime> {
|
||||||
self.0.last().unwrap().end_pts
|
self.0.last().map(|item| item.end_pts)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_empty(&self) -> bool {
|
||||||
|
self.0.is_empty()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn push(&mut self, item: InputItem) -> Result<(), Error> {
|
||||||
|
if item.discont && !self.is_empty() {
|
||||||
|
return Err(anyhow!("can't push discont item on non-empty accumulator"));
|
||||||
|
}
|
||||||
|
|
||||||
|
self.0.push(item);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn drain(&mut self, up_to_punctuation: bool) -> Self {
|
||||||
|
let items = match up_to_punctuation {
|
||||||
|
true => {
|
||||||
|
if let Some(punctuation_index) = self.0.iter().rposition(|item| item.is_punctuation)
|
||||||
|
{
|
||||||
|
let (items, trailing) = self.0.split_at(punctuation_index + 1);
|
||||||
|
|
||||||
|
let items = items.to_vec();
|
||||||
|
|
||||||
|
self.0 = trailing.to_vec();
|
||||||
|
|
||||||
|
gst::log!(CAT, "drained up to punctuation: {items:?}");
|
||||||
|
|
||||||
|
items
|
||||||
|
} else {
|
||||||
|
gst::log!(CAT, "drained all items: {:?}", self.0);
|
||||||
|
|
||||||
|
self.0.drain(..).collect()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
false => {
|
||||||
|
gst::log!(CAT, "drained all items: {:?}", self.0);
|
||||||
|
|
||||||
|
self.0.drain(..).collect()
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Self(items)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn timeout(&mut self, now: gst::ClockTime, upstream_min: gst::ClockTime) -> Option<Self> {
|
||||||
|
if let Some(start_rtime) = self.start_rtime() {
|
||||||
|
if start_rtime + upstream_min < now {
|
||||||
|
gst::debug!(
|
||||||
|
CAT,
|
||||||
|
"draining on timeout: {start_rtime} + {upstream_min} < {now}",
|
||||||
|
);
|
||||||
|
Some(self.drain(true))
|
||||||
|
} else {
|
||||||
|
gst::trace!(
|
||||||
|
CAT,
|
||||||
|
"queued content is not late: {start_rtime} + {upstream_min} >= {now}"
|
||||||
|
);
|
||||||
|
None
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
gst::trace!(CAT, "no queued content, cannot be late");
|
||||||
|
None
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -110,7 +175,7 @@ struct State {
|
||||||
// (live, min, max)
|
// (live, min, max)
|
||||||
upstream_latency: Option<(bool, gst::ClockTime, Option<gst::ClockTime>)>,
|
upstream_latency: Option<(bool, gst::ClockTime, Option<gst::ClockTime>)>,
|
||||||
segment: Option<gst::FormattedSegment<gst::ClockTime>>,
|
segment: Option<gst::FormattedSegment<gst::ClockTime>>,
|
||||||
accumulator: Option<InputItems>,
|
accumulator: InputItems,
|
||||||
client: Option<aws_sdk_translate::Client>,
|
client: Option<aws_sdk_translate::Client>,
|
||||||
send_abort_handle: Option<AbortHandle>,
|
send_abort_handle: Option<AbortHandle>,
|
||||||
translate_tx: Option<mpsc::Sender<TranslateInput>>,
|
translate_tx: Option<mpsc::Sender<TranslateInput>>,
|
||||||
|
@ -125,7 +190,7 @@ impl Default for State {
|
||||||
Self {
|
Self {
|
||||||
upstream_latency: None,
|
upstream_latency: None,
|
||||||
segment: None,
|
segment: None,
|
||||||
accumulator: None,
|
accumulator: InputItems(vec![]),
|
||||||
client: None,
|
client: None,
|
||||||
send_abort_handle: None,
|
send_abort_handle: None,
|
||||||
translate_tx: None,
|
translate_tx: None,
|
||||||
|
@ -254,7 +319,7 @@ impl Translate {
|
||||||
|
|
||||||
let (pts, duration) = gap.get();
|
let (pts, duration) = gap.get();
|
||||||
|
|
||||||
if state.accumulator.is_none() {
|
if state.accumulator.is_empty() {
|
||||||
if let Some(translate_tx) = state.translate_tx.as_ref() {
|
if let Some(translate_tx) = state.translate_tx.as_ref() {
|
||||||
let _ = translate_tx.send(TranslateInput::Gap { pts, duration });
|
let _ = translate_tx.send(TranslateInput::Gap { pts, duration });
|
||||||
}
|
}
|
||||||
|
@ -266,8 +331,8 @@ impl Translate {
|
||||||
let translate_tx = self.state.lock().unwrap().translate_tx.take();
|
let translate_tx = self.state.lock().unwrap().translate_tx.take();
|
||||||
if let Some(translate_tx) = translate_tx {
|
if let Some(translate_tx) = translate_tx {
|
||||||
gst::debug!(CAT, imp = self, "received EOS, draining");
|
gst::debug!(CAT, imp = self, "received EOS, draining");
|
||||||
|
let items = self.state.lock().unwrap().accumulator.drain(false);
|
||||||
let _ = translate_tx.send(TranslateInput::Items(self.drain(false)));
|
let _ = translate_tx.send(TranslateInput::Items(items));
|
||||||
}
|
}
|
||||||
|
|
||||||
true
|
true
|
||||||
|
@ -292,7 +357,7 @@ impl Translate {
|
||||||
};
|
};
|
||||||
|
|
||||||
if drain {
|
if drain {
|
||||||
let items = self.drain(false);
|
let items = self.state.lock().unwrap().accumulator.drain(false);
|
||||||
|
|
||||||
if let Some(translate_tx) = self.state.lock().unwrap().translate_tx.as_ref() {
|
if let Some(translate_tx) = self.state.lock().unwrap().translate_tx.as_ref() {
|
||||||
let _ = translate_tx.send(TranslateInput::Items(items));
|
let _ = translate_tx.send(TranslateInput::Items(items));
|
||||||
|
@ -331,39 +396,22 @@ impl Translate {
|
||||||
}
|
}
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
let start_rtime = self
|
let to_translate = self
|
||||||
.state
|
.state
|
||||||
.lock()
|
.lock()
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.accumulator
|
.accumulator
|
||||||
.as_ref()
|
.timeout(now, upstream_min);
|
||||||
.map(|accumulator| accumulator.start_rtime());
|
if let Some(to_translate) = to_translate {
|
||||||
|
self.do_send(to_translate)?;
|
||||||
if let Some(start_rtime) = start_rtime {
|
|
||||||
if start_rtime + upstream_min < now {
|
|
||||||
gst::debug!(
|
|
||||||
CAT,
|
|
||||||
imp = self,
|
|
||||||
"draining on timeout: {start_rtime} + {upstream_min} < {now}",
|
|
||||||
);
|
|
||||||
self.do_send(self.drain(true))?;
|
|
||||||
} else {
|
|
||||||
gst::trace!(
|
|
||||||
CAT,
|
|
||||||
imp = self,
|
|
||||||
"queued content is not late: {start_rtime} + {upstream_min} >= {now}"
|
|
||||||
);
|
|
||||||
return Ok(gst::FlowSuccess::Ok);
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
gst::trace!(CAT, imp = self, "no queued content, cannot be late");
|
|
||||||
return Ok(gst::FlowSuccess::Ok);
|
return Ok(gst::FlowSuccess::Ok);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn do_send(&self, to_translate: InputItems) -> Result<gst::FlowSuccess, gst::FlowError> {
|
fn do_send(&self, to_translate: InputItems) -> Result<gst::FlowSuccess, gst::FlowError> {
|
||||||
if to_translate.0.is_empty() {
|
if to_translate.is_empty() {
|
||||||
gst::trace!(CAT, imp = self, "nothing to send, returning early");
|
gst::trace!(CAT, imp = self, "nothing to send, returning early");
|
||||||
return Ok(gst::FlowSuccess::Ok);
|
return Ok(gst::FlowSuccess::Ok);
|
||||||
}
|
}
|
||||||
|
@ -402,43 +450,6 @@ impl Translate {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn drain(&self, up_to_punctuation: bool) -> InputItems {
|
|
||||||
let mut state = self.state.lock().unwrap();
|
|
||||||
|
|
||||||
let Some(accumulator) = state.accumulator.take() else {
|
|
||||||
gst::trace!(CAT, imp = self, "accumulator is empty");
|
|
||||||
return InputItems(vec![]);
|
|
||||||
};
|
|
||||||
|
|
||||||
let items = match up_to_punctuation {
|
|
||||||
true => {
|
|
||||||
if let Some(punctuation_index) =
|
|
||||||
accumulator.0.iter().rposition(|item| item.is_punctuation)
|
|
||||||
{
|
|
||||||
let (items, trailing) = accumulator.0.split_at(punctuation_index + 1);
|
|
||||||
|
|
||||||
if !trailing.is_empty() {
|
|
||||||
state.accumulator = Some(InputItems(trailing.to_vec()));
|
|
||||||
}
|
|
||||||
|
|
||||||
gst::log!(CAT, imp = self, "drained up to punctuation: {items:?}");
|
|
||||||
|
|
||||||
InputItems(items.to_vec())
|
|
||||||
} else {
|
|
||||||
gst::log!(CAT, imp = self, "drained all items: {accumulator:?}");
|
|
||||||
accumulator
|
|
||||||
}
|
|
||||||
}
|
|
||||||
false => {
|
|
||||||
gst::log!(CAT, imp = self, "drained all items: {accumulator:?}");
|
|
||||||
|
|
||||||
accumulator
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
items
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn send(&self, to_translate: InputItems) -> Result<Vec<TranslateOutput>, Error> {
|
async fn send(&self, to_translate: InputItems) -> Result<Vec<TranslateOutput>, Error> {
|
||||||
let (input_lang, output_lang, latency, tokenization_method) = {
|
let (input_lang, output_lang, latency, tokenization_method) = {
|
||||||
let settings = self.settings.lock().unwrap();
|
let settings = self.settings.lock().unwrap();
|
||||||
|
@ -457,7 +468,7 @@ impl Translate {
|
||||||
.segment
|
.segment
|
||||||
.as_mut()
|
.as_mut()
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.set_position(Some(to_translate.end_pts()));
|
.set_position(to_translate.end_pts());
|
||||||
|
|
||||||
(
|
(
|
||||||
state.client.as_ref().unwrap().clone(),
|
state.client.as_ref().unwrap().clone(),
|
||||||
|
@ -746,7 +757,7 @@ impl Translate {
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
let drained_items = if buffer.flags().contains(gst::BufferFlags::DISCONT) {
|
let drained_items = if buffer.flags().contains(gst::BufferFlags::DISCONT) {
|
||||||
let items = self.drain(false);
|
let items = self.state.lock().unwrap().accumulator.drain(false);
|
||||||
|
|
||||||
gst::log!(CAT, imp = self, "draining on discont");
|
gst::log!(CAT, imp = self, "draining on discont");
|
||||||
|
|
||||||
|
@ -800,18 +811,8 @@ impl Translate {
|
||||||
discont,
|
discont,
|
||||||
};
|
};
|
||||||
|
|
||||||
if let Some(accumulator) = state.accumulator.as_mut() {
|
gst::log!(CAT, imp = self, "queuing item on accumulator: {item:?}");
|
||||||
gst::log!(CAT, imp = self, "queuing item on accumulator: {item:?}");
|
state.accumulator.push(item).unwrap();
|
||||||
accumulator.0.push(item);
|
|
||||||
} else {
|
|
||||||
gst::log!(
|
|
||||||
CAT,
|
|
||||||
imp = self,
|
|
||||||
"creating new accumulator with item: {item:?}"
|
|
||||||
);
|
|
||||||
state.accumulator = Some(InputItems(vec![item]));
|
|
||||||
}
|
|
||||||
|
|
||||||
state.chained_one = true;
|
state.chained_one = true;
|
||||||
|
|
||||||
gst::trace!(
|
gst::trace!(
|
||||||
|
@ -1150,3 +1151,156 @@ impl ElementImpl for Translate {
|
||||||
Some(gst::SystemClock::obtain())
|
Some(gst::SystemClock::obtain())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::{InputItem, InputItems};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn accumulator_basic() {
|
||||||
|
let mut accumulator = InputItems(vec![]);
|
||||||
|
|
||||||
|
assert!(accumulator.is_empty());
|
||||||
|
assert_eq!(accumulator.start_rtime(), None);
|
||||||
|
assert_eq!(accumulator.start_pts(), None);
|
||||||
|
assert_eq!(accumulator.discont(), false);
|
||||||
|
assert_eq!(accumulator.end_pts(), None);
|
||||||
|
assert!(accumulator.drain(false).is_empty());
|
||||||
|
|
||||||
|
assert!(accumulator
|
||||||
|
.push(InputItem {
|
||||||
|
content: "0".into(),
|
||||||
|
pts: gst::ClockTime::from_nseconds(0),
|
||||||
|
rtime: gst::ClockTime::from_nseconds(0),
|
||||||
|
end_pts: gst::ClockTime::from_nseconds(1),
|
||||||
|
is_punctuation: false,
|
||||||
|
discont: true
|
||||||
|
})
|
||||||
|
.is_ok());
|
||||||
|
|
||||||
|
assert!(accumulator
|
||||||
|
.push(InputItem {
|
||||||
|
content: "2".into(),
|
||||||
|
pts: gst::ClockTime::from_nseconds(2),
|
||||||
|
rtime: gst::ClockTime::from_nseconds(2),
|
||||||
|
end_pts: gst::ClockTime::from_nseconds(3),
|
||||||
|
is_punctuation: false,
|
||||||
|
discont: false
|
||||||
|
})
|
||||||
|
.is_ok());
|
||||||
|
|
||||||
|
assert!(accumulator
|
||||||
|
.push(InputItem {
|
||||||
|
content: "10".into(),
|
||||||
|
pts: gst::ClockTime::from_nseconds(10),
|
||||||
|
rtime: gst::ClockTime::from_nseconds(20),
|
||||||
|
end_pts: gst::ClockTime::from_nseconds(10),
|
||||||
|
is_punctuation: false,
|
||||||
|
discont: true
|
||||||
|
})
|
||||||
|
.is_err());
|
||||||
|
|
||||||
|
assert!(!accumulator.is_empty());
|
||||||
|
assert_eq!(
|
||||||
|
accumulator.start_rtime(),
|
||||||
|
Some(gst::ClockTime::from_nseconds(0))
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
accumulator.start_pts(),
|
||||||
|
Some(gst::ClockTime::from_nseconds(0))
|
||||||
|
);
|
||||||
|
assert_eq!(accumulator.discont(), true);
|
||||||
|
assert_eq!(
|
||||||
|
accumulator.end_pts(),
|
||||||
|
Some(gst::ClockTime::from_nseconds(3))
|
||||||
|
);
|
||||||
|
|
||||||
|
assert!(!accumulator.drain(false).is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_accumulator_timeout() {
|
||||||
|
let mut accumulator = InputItems(vec![
|
||||||
|
InputItem {
|
||||||
|
content: "0".into(),
|
||||||
|
pts: gst::ClockTime::from_nseconds(0),
|
||||||
|
rtime: gst::ClockTime::from_nseconds(0),
|
||||||
|
end_pts: gst::ClockTime::from_nseconds(1),
|
||||||
|
is_punctuation: false,
|
||||||
|
discont: true,
|
||||||
|
},
|
||||||
|
InputItem {
|
||||||
|
content: "2".into(),
|
||||||
|
pts: gst::ClockTime::from_nseconds(2),
|
||||||
|
rtime: gst::ClockTime::from_nseconds(2),
|
||||||
|
end_pts: gst::ClockTime::from_nseconds(3),
|
||||||
|
is_punctuation: false,
|
||||||
|
discont: false,
|
||||||
|
},
|
||||||
|
]);
|
||||||
|
|
||||||
|
let upstream_min = gst::ClockTime::from_nseconds(5);
|
||||||
|
|
||||||
|
assert!(accumulator
|
||||||
|
.timeout(gst::ClockTime::from_nseconds(5), upstream_min)
|
||||||
|
.is_none());
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
accumulator
|
||||||
|
.timeout(gst::ClockTime::from_nseconds(6), upstream_min)
|
||||||
|
.unwrap()
|
||||||
|
.0
|
||||||
|
.len(),
|
||||||
|
2
|
||||||
|
);
|
||||||
|
|
||||||
|
assert!(accumulator.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_accumulator_timeout_punctuation() {
|
||||||
|
let mut accumulator = InputItems(vec![
|
||||||
|
InputItem {
|
||||||
|
content: "0".into(),
|
||||||
|
pts: gst::ClockTime::from_nseconds(0),
|
||||||
|
rtime: gst::ClockTime::from_nseconds(0),
|
||||||
|
end_pts: gst::ClockTime::from_nseconds(1),
|
||||||
|
is_punctuation: false,
|
||||||
|
discont: true,
|
||||||
|
},
|
||||||
|
InputItem {
|
||||||
|
content: ".".into(),
|
||||||
|
pts: gst::ClockTime::from_nseconds(2),
|
||||||
|
rtime: gst::ClockTime::from_nseconds(2),
|
||||||
|
end_pts: gst::ClockTime::from_nseconds(3),
|
||||||
|
is_punctuation: true,
|
||||||
|
discont: false,
|
||||||
|
},
|
||||||
|
InputItem {
|
||||||
|
content: "5".into(),
|
||||||
|
pts: gst::ClockTime::from_nseconds(5),
|
||||||
|
rtime: gst::ClockTime::from_nseconds(5),
|
||||||
|
end_pts: gst::ClockTime::from_nseconds(6),
|
||||||
|
is_punctuation: false,
|
||||||
|
discont: false,
|
||||||
|
},
|
||||||
|
]);
|
||||||
|
|
||||||
|
let upstream_min = gst::ClockTime::from_nseconds(5);
|
||||||
|
|
||||||
|
assert!(accumulator
|
||||||
|
.timeout(gst::ClockTime::from_nseconds(5), upstream_min)
|
||||||
|
.is_none());
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
accumulator
|
||||||
|
.timeout(gst::ClockTime::from_nseconds(6), upstream_min)
|
||||||
|
.unwrap()
|
||||||
|
.0
|
||||||
|
.len(),
|
||||||
|
2
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(accumulator.0.len(), 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue