cea708mux: use cea708-types push_code_into_single_service

Simplifies the code significantly as we don't need to track whether a service or
a packet is full or not.

Fixes at least one case of a WouldOverflow panic and a subtraction underflow.

Part-of: <https://gitlab.freedesktop.org/gstreamer/gst-plugins-rs/-/merge_requests/2275>
This commit is contained in:
Matthew Waters 2025-06-05 17:21:32 +10:00 committed by Sebastian Dröge
parent 5a25e2a12b
commit 3d25763ac9
4 changed files with 171 additions and 105 deletions

4
Cargo.lock generated
View file

@ -1121,9 +1121,9 @@ dependencies = [
[[package]]
name = "cea708-types"
version = "0.4.0"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b4d0a0b75ef1572334ddfd197090a9f550b6a8852663cbb69e7af8e8c5641ef"
checksum = "de28b1d549e7f8f53a746fb36ae4c10c776a8e004950b527be1669a58667ae0b"
dependencies = [
"log",
"muldiv",

View file

@ -21,7 +21,7 @@ byteorder = "1"
serde = { version = "1.0", features = ["derive"] }
serde_json = { version = "1.0", features = ["raw_value"] }
cdp-types = "0.3"
cea708-types = "0.4.0"
cea708-types = "0.4.1"
cea608-types = "0.1.1"
gst = { workspace = true, features = ["v1_20"]}
gst-base = { workspace = true, features = ["v1_22"]}

View file

@ -10,7 +10,7 @@ use std::collections::{HashMap, VecDeque};
use std::sync::Mutex;
use std::time::Duration;
use cea708_types::{CCDataParser, Service};
use cea708_types::CCDataParser;
use cea708_types::{CCDataWriter, DTVCCPacket, Framerate};
use gst::glib;
use gst::prelude::*;
@ -256,6 +256,7 @@ impl AggregatorImpl for Cea708Mux {
.all(|(_service_no, pending_codes)| pending_codes.is_empty())
&& state.writer.buffered_packet_duration() == Duration::ZERO
{
gst::info!(CAT, imp = self, "sending EOS");
return Err(gst::FlowError::Eos);
}
@ -263,11 +264,25 @@ impl AggregatorImpl for Cea708Mux {
.selected_samples(start_running_time, None, duration, None);
// phase 2: write stored data into output packet
let mut services = HashMap::new();
let mut packet = DTVCCPacket::new(state.dtvcc_seq_no & 0x3);
// an empty packet does not account for the packet header in free space calculations so do
// that manually here.
let mut free_space = packet.free_space() - 1;
let mut output = DTVCCPacket::new(state.dtvcc_seq_no & 0x3);
// try putting any previously overflowed service codes into a new
// output service.
for (service_no, pending_codes) in state.pending_services.iter_mut() {
while let Some(code) = pending_codes.pop_front() {
match output.push_code_into_single_service(*service_no, code.clone()) {
Ok(_) => (),
Err(cea708_types::WriterError::WouldOverflow(_)) => {
pending_codes.push_front(code);
break;
}
Err(
cea708_types::WriterError::ReadOnly
| cea708_types::WriterError::EmptyService,
) => unreachable!(),
}
}
}
for pad in sinkpads.iter().map(|pad| {
pad.downcast_ref::<super::Cea708MuxSinkPad>()
@ -281,45 +296,12 @@ impl AggregatorImpl for Cea708Mux {
CeaFormat::CcData => {
while let Some(packet) = pad_state.ccp_parser.pop_packet() {
for service in packet.services() {
let service_no = service.number();
if service.number() == 0 {
// skip null service
continue;
}
let new_service = services
.entry(service_no)
.or_insert_with_key(|&n| Service::new(n));
let prev_service_len = new_service.len();
// try putting any previously overflowed service codes into a new
// output service.
let mut overflowed = false;
if let Some(pending_codes) =
state.pending_services.get_mut(&service.number())
{
while let Some(code) = pending_codes.pop_front() {
if code.byte_len() + new_service.len() - prev_service_len
> free_space
{
overflowed = true;
pending_codes.push_front(code);
break;
}
match new_service.push_code(&code) {
Ok(_) => (),
Err(cea708_types::WriterError::WouldOverflow(_)) => {
overflowed = true;
pending_codes.push_front(code);
break;
}
Err(
cea708_types::WriterError::ReadOnly
| cea708_types::WriterError::EmptyService,
) => unreachable!(),
}
}
}
for code in service.codes() {
gst::trace!(
CAT,
@ -334,18 +316,10 @@ impl AggregatorImpl for Cea708Mux {
.or_default()
.push_back(code.clone());
} else {
if code.byte_len() + new_service.len() - prev_service_len
> free_space
{
overflowed = true;
state
.pending_services
.entry(service.number())
.or_default()
.push_back(code.clone());
continue;
}
match new_service.push_code(code) {
match output.push_code_into_single_service(
service.number(),
code.clone(),
) {
Ok(_) => (),
Err(cea708_types::WriterError::WouldOverflow(_)) => {
overflowed = true;
@ -362,63 +336,19 @@ impl AggregatorImpl for Cea708Mux {
}
}
}
gst::trace!(
CAT,
obj = pad,
"free_space: {free_space}, service len: {}",
new_service.len()
);
free_space -= new_service.len() - prev_service_len;
}
}
for (service_no, pending_codes) in state.pending_services.iter_mut() {
let new_service = services
.entry(*service_no)
.or_insert_with_key(|&n| Service::new(n));
let prev_service_len = new_service.len();
while let Some(code) = pending_codes.pop_front() {
if code.byte_len() + new_service.len() - prev_service_len > free_space {
pending_codes.push_front(code);
break;
}
match new_service.push_code(&code) {
Ok(_) => (),
Err(cea708_types::WriterError::WouldOverflow(_)) => {
pending_codes.push_front(code);
break;
}
Err(
cea708_types::WriterError::ReadOnly
| cea708_types::WriterError::EmptyService,
) => unreachable!(),
}
}
free_space -= new_service.len() - prev_service_len;
}
}
_ => (),
}
}
for (_service_no, service) in services.into_iter().filter(|(_, s)| !s.codes().is_empty()) {
// FIXME: handle needing to split services
gst::trace!(
CAT,
imp = self,
"Adding service {} to packet with sequence {}, {:?}",
service.number(),
packet.sequence_no(),
service.codes(),
);
packet.push_service(service).unwrap();
if packet.sequence_no() == state.dtvcc_seq_no & 0x3 {
state.dtvcc_seq_no = state.dtvcc_seq_no.wrapping_add(1);
}
if !output.is_empty() && output.sequence_no() == state.dtvcc_seq_no & 0x3 {
state.dtvcc_seq_no = state.dtvcc_seq_no.wrapping_add(1);
}
let mut data = vec![];
state.writer.push_packet(packet);
state.writer.push_packet(output);
let _ = state.writer.write(fps, &mut data);
state.n_frames += 1;
drop(state);

View file

@ -6,6 +6,8 @@
//
// SPDX-License-Identifier: MPL-2.0
use std::collections::HashMap;
use gst::prelude::*;
use pretty_assertions::assert_eq;
@ -22,14 +24,14 @@ fn init() {
});
}
fn gen_cc_data(seq: u8, service: u8, codes: &[Code]) -> gst::Buffer {
fn gen_cc_data(seq: u8, service_no: u8, codes: &[Code]) -> gst::Buffer {
assert!(seq < 4);
assert!(service < 64);
assert!(service_no < 64);
let fps = Framerate::new(30, 1);
let mut writer = CCDataWriter::default();
let mut packet = DTVCCPacket::new(seq);
let mut service = Service::new(service);
let mut service = Service::new(service_no);
for c in codes {
service.push_code(c).unwrap();
}
@ -37,7 +39,7 @@ fn gen_cc_data(seq: u8, service: u8, codes: &[Code]) -> gst::Buffer {
writer.push_packet(packet);
let mut data = vec![];
writer.write(fps, &mut data).unwrap();
println!("generated {data:x?}");
println!("generated {seq} for service {service_no} {data:x?}");
let data = data.split_off(2);
let mut buf = gst::Buffer::from_mut_slice(data);
{
@ -230,3 +232,137 @@ fn test_cea708mux_inputs_overflow_output() {
assert_eq!(&codes[..30], &CODES[no as usize..no as usize + 30]);
}
}
#[test]
fn test_cea708mux_inputs_overflow_output_new_service() {
init();
static CODES: [Code; 46] = [
Code::LatinLowerA,
Code::LatinLowerB,
Code::LatinLowerC,
Code::LatinLowerD,
Code::LatinLowerE,
Code::LatinLowerF,
Code::LatinLowerG,
Code::LatinLowerH,
Code::LatinLowerI,
Code::LatinLowerJ,
Code::LatinLowerK,
Code::LatinLowerL,
Code::LatinLowerM,
Code::LatinLowerN,
Code::LatinLowerO,
Code::LatinLowerP,
Code::LatinLowerQ,
Code::LatinLowerR,
Code::LatinLowerS,
Code::LatinLowerT,
Code::LatinLowerU,
Code::LatinLowerV,
Code::LatinLowerW,
Code::LatinLowerX,
Code::LatinLowerY,
Code::LatinLowerZ,
Code::LatinCapitalA,
Code::LatinCapitalB,
Code::LatinCapitalC,
Code::LatinCapitalD,
Code::LatinCapitalE,
Code::LatinCapitalF,
Code::LatinCapitalG,
Code::LatinCapitalH,
Code::LatinCapitalI,
Code::LatinCapitalJ,
Code::LatinCapitalK,
Code::LatinCapitalL,
Code::LatinCapitalM,
Code::LatinCapitalN,
Code::LatinCapitalO,
Code::LatinCapitalP,
Code::LatinCapitalQ,
Code::LatinCapitalR,
Code::LatinCapitalS,
Code::LatinCapitalT,
];
let mut h = gst_check::Harness::with_padnames("cea708mux", None, Some("src"));
let mut sinks = (0..6)
.map(|idx| {
let mut sink = gst_check::Harness::with_element(
&h.element().unwrap(),
Some(&format!("sink_{idx}")),
None,
);
sink.set_src_caps_str("closedcaption/x-cea-708,format=cc_data,framerate=60/1");
sink
})
.collect::<Vec<_>>();
let eos = gst::event::Eos::new();
for (i, sink) in sinks.iter_mut().enumerate() {
let buf = gen_cc_data(0, i as u8 + 1, &CODES[i..i + 30]);
sink.push(buf).unwrap();
}
for (i, sink) in sinks.iter_mut().enumerate() {
let i = 5 - i;
let buf = gen_cc_data(1, i as u8 + 1, &CODES[i + 6..i + 36]);
sink.push(buf).unwrap();
sink.push_event(eos.clone());
}
let mut parser = CCDataParser::new();
let mut seen_services: HashMap<u8, Vec<Code>, _> = HashMap::new();
let mut parsed_sequence_no = 0;
loop {
let mut parsed_packet = None;
while parsed_packet.is_none() {
let out = h.pull().unwrap();
let readable = out.map_readable().unwrap();
let mut cc_data = vec![0; 2];
cc_data[0] = 0x80 | 0x40 | ((readable.len() / 3) & 0x1f) as u8;
cc_data[1] = 0xFF;
cc_data.extend(readable.iter());
println!("pushed {cc_data:x?}");
parser.push(&cc_data).unwrap();
println!("parser: {parser:x?}");
parsed_packet = parser.pop_packet();
}
let parsed_packet = parsed_packet.unwrap();
println!("parsed: {parsed_packet:?}");
assert_eq!(parsed_packet.sequence_no(), parsed_sequence_no);
parsed_sequence_no += 1;
let services = parsed_packet.services();
// TODO: deterministic service ordering?
for service in services {
assert!((1..=6).contains(&service.number()));
seen_services
.entry(service.number())
.and_modify(|entry| entry.extend(service.codes().iter().cloned()))
.or_insert(service.codes().to_vec());
}
for (service_no, codes) in seen_services.iter() {
println!(
"seen service: {service_no}, codes (len: {}): {codes:?}",
codes.len()
);
}
if seen_services.keys().len() >= 6 && seen_services.values().all(|svc| svc.len() >= 60) {
break;
}
}
let mut service_numbers = seen_services.keys().copied().collect::<Vec<_>>();
service_numbers.sort();
assert_eq!(service_numbers, (1..=6).collect::<Vec<_>>());
for no in service_numbers {
let codes = seen_services.get(&no).unwrap();
println!("service {no}: {:?}", codes);
let offset = no as usize - 1;
// one of the services will have a length that is 1 byte shorter than others due to size
// limits of the packet.
assert_eq!(60, codes.len());
assert_eq!(&codes[..30], &CODES[offset..offset + 30]);
assert_eq!(&codes[30..60], &CODES[offset + 6..offset + 36]);
}
}