diff --git a/video/closedcaption/src/cea708mux/imp.rs b/video/closedcaption/src/cea708mux/imp.rs index 1aa52b4e6..4e9eb5a3f 100644 --- a/video/closedcaption/src/cea708mux/imp.rs +++ b/video/closedcaption/src/cea708mux/imp.rs @@ -8,6 +8,7 @@ use std::collections::{HashMap, VecDeque}; use std::sync::Mutex; +use std::time::Duration; use cea708_types::{CCDataParser, Service}; use cea708_types::{CCDataWriter, DTVCCPacket, Framerate}; @@ -89,6 +90,7 @@ struct State { dtvcc_seq_no: u8, writer: CCDataWriter, n_frames: u64, + pending_services: HashMap>, } impl Default for State { @@ -102,6 +104,7 @@ impl Default for State { dtvcc_seq_no: 0, writer, n_frames: 0, + pending_services: HashMap::default(), } } } @@ -246,7 +249,13 @@ impl AggregatorImpl for Cea708Mux { if need_data && !timeout { return Err(gst_base::AGGREGATOR_FLOW_NEED_DATA); } - if all_eos { + if all_eos + && state + .pending_services + .iter() + .all(|(_service_no, pending_codes)| pending_codes.is_empty()) + && state.writer.buffered_packet_duration() == Duration::ZERO + { return Err(gst::FlowError::Eos); } @@ -255,6 +264,10 @@ impl AggregatorImpl for Cea708Mux { // phase 2: write stored data into output packet let mut services = HashMap::new(); + let mut packet = DTVCCPacket::new(state.dtvcc_seq_no & 0x3); + // an empty packet does not account for the packet header in free space calculations so do + // that manually here. + let mut free_space = packet.free_space() - 1; for pad in sinkpads.iter().map(|pad| { pad.downcast_ref::() @@ -276,12 +289,22 @@ impl AggregatorImpl for Cea708Mux { let new_service = services .entry(service_no) .or_insert_with_key(|&n| Service::new(n)); + let prev_service_len = new_service.len(); + // try putting any previously overflowed service codes into a new + // output service. let mut overflowed = false; if let Some(pending_codes) = - pad_state.pending_services.get_mut(&service.number()) + state.pending_services.get_mut(&service.number()) { while let Some(code) = pending_codes.pop_front() { + if code.byte_len() + new_service.len() - prev_service_len + > free_space + { + overflowed = true; + pending_codes.push_front(code); + break; + } match new_service.push_code(&code) { Ok(_) => (), Err(cea708_types::WriterError::WouldOverflow(_)) => { @@ -302,17 +325,28 @@ impl AggregatorImpl for Cea708Mux { service.number() ); if overflowed { - pad_state + state .pending_services .entry(service.number()) .or_default() .push_back(code.clone()); } else { + if code.byte_len() + new_service.len() - prev_service_len + > free_space + { + overflowed = true; + state + .pending_services + .entry(service.number()) + .or_default() + .push_back(code.clone()); + continue; + } match new_service.push_code(code) { Ok(_) => (), Err(cea708_types::WriterError::WouldOverflow(_)) => { overflowed = true; - pad_state + state .pending_services .entry(service.number()) .or_default() @@ -322,14 +356,26 @@ impl AggregatorImpl for Cea708Mux { } } } + gst::trace!( + CAT, + obj = pad, + "free_space: {free_space}, service len: {}", + new_service.len() + ); + free_space -= new_service.len() - prev_service_len; } } - for (service_no, pending_codes) in pad_state.pending_services.iter_mut() { + for (service_no, pending_codes) in state.pending_services.iter_mut() { let new_service = services .entry(*service_no) .or_insert_with_key(|&n| Service::new(n)); + let prev_service_len = new_service.len(); while let Some(code) = pending_codes.pop_front() { + if code.byte_len() + new_service.len() - prev_service_len > free_space { + pending_codes.push_front(code); + break; + } match new_service.push_code(&code) { Ok(_) => (), Err(cea708_types::WriterError::WouldOverflow(_)) => { @@ -345,15 +391,14 @@ impl AggregatorImpl for Cea708Mux { } } - let mut packet = DTVCCPacket::new(state.dtvcc_seq_no & 0x3); - for (_service_no, service) in services.into_iter().filter(|(_, s)| !s.codes().is_empty()) { // FIXME: handle needing to split services gst::trace!( CAT, imp = self, - "Adding service {} to packet", - service.number() + "Adding service {} to packet with sequence {}", + service.number(), + packet.sequence_no(), ); packet.push_service(service).unwrap(); if packet.sequence_no() == state.dtvcc_seq_no & 0x3 { @@ -666,7 +711,6 @@ impl ObjectSubclass for Cea708Mux { struct PadState { format: CeaFormat, ccp_parser: CCDataParser, - pending_services: HashMap>, pending_buffer: Option, } @@ -677,7 +721,6 @@ impl Default for PadState { Self { format: CeaFormat::default(), ccp_parser, - pending_services: HashMap::default(), pending_buffer: None, } } diff --git a/video/closedcaption/tests/cea708mux.rs b/video/closedcaption/tests/cea708mux.rs index ff7b2e356..b64985ab0 100644 --- a/video/closedcaption/tests/cea708mux.rs +++ b/video/closedcaption/tests/cea708mux.rs @@ -42,6 +42,7 @@ fn gen_cc_data(seq: u8, service: u8, codes: &[Code]) -> gst::Buffer { { let buf = buf.get_mut().unwrap(); buf.set_pts(0.nseconds()); + buf.set_duration(gst::ClockTime::from_mseconds(400)); } buf } @@ -128,3 +129,105 @@ fn test_cea708mux_2pads_cc_data() { unreachable!(); } } + +#[test] +fn test_cea708mux_inputs_overflow_output() { + init(); + + static CODES: [Code; 36] = [ + Code::LatinLowerA, + Code::LatinLowerB, + Code::LatinLowerC, + Code::LatinLowerD, + Code::LatinLowerE, + Code::LatinLowerF, + Code::LatinLowerG, + Code::LatinLowerH, + Code::LatinLowerI, + Code::LatinLowerJ, + Code::LatinLowerK, + Code::LatinLowerL, + Code::LatinLowerM, + Code::LatinLowerN, + Code::LatinLowerO, + Code::LatinLowerP, + Code::LatinLowerQ, + Code::LatinLowerR, + Code::LatinLowerS, + Code::LatinLowerT, + Code::LatinLowerU, + Code::LatinLowerV, + Code::LatinLowerW, + Code::LatinLowerX, + Code::LatinLowerY, + Code::LatinLowerZ, + Code::LatinCapitalA, + Code::LatinCapitalB, + Code::LatinCapitalC, + Code::LatinCapitalD, + Code::LatinCapitalE, + Code::LatinCapitalF, + Code::LatinCapitalG, + Code::LatinCapitalH, + Code::LatinCapitalI, + Code::LatinCapitalJ, + ]; + + let mut h = gst_check::Harness::with_padnames("cea708mux", None, Some("src")); + let mut sink_0 = gst_check::Harness::with_element(&h.element().unwrap(), Some("sink_0"), None); + sink_0.set_src_caps_str("closedcaption/x-cea-708,format=cc_data,framerate=60/1"); + let mut sink_1 = gst_check::Harness::with_element(&h.element().unwrap(), Some("sink_1"), None); + sink_1.set_src_caps_str("closedcaption/x-cea-708,format=cc_data,framerate=60/1"); + let mut sink_2 = gst_check::Harness::with_element(&h.element().unwrap(), Some("sink_2"), None); + sink_2.set_src_caps_str("closedcaption/x-cea-708,format=cc_data,framerate=60/1"); + let mut sink_3 = gst_check::Harness::with_element(&h.element().unwrap(), Some("sink_3"), None); + sink_3.set_src_caps_str("closedcaption/x-cea-708,format=cc_data,framerate=60/1"); + + let eos = gst::event::Eos::new(); + + let buf = gen_cc_data(0, 1, &CODES[1..32]); + sink_0.push(buf).unwrap(); + sink_0.push_event(eos.clone()); + + let buf = gen_cc_data(0, 2, &CODES[2..33]); + sink_1.push(buf).unwrap(); + sink_1.push_event(eos.clone()); + + let buf = gen_cc_data(0, 3, &CODES[3..34]); + sink_2.push(buf).unwrap(); + sink_2.push_event(eos.clone()); + + let buf = gen_cc_data(0, 4, &CODES[4..35]); + sink_3.push(buf).unwrap(); + sink_3.push_event(eos.clone()); + + let mut parser = CCDataParser::new(); + let mut parsed_packet = None; + while parsed_packet.is_none() { + let out = h.pull().unwrap(); + let readable = out.map_readable().unwrap(); + let mut cc_data = vec![0; 2]; + cc_data[0] = 0x80 | 0x40 | ((readable.len() / 3) & 0x1f) as u8; + cc_data[1] = 0xFF; + cc_data.extend(readable.iter()); + println!("pushed {cc_data:x?}"); + parser.push(&cc_data).unwrap(); + println!("parser: {parser:x?}"); + parsed_packet = parser.pop_packet(); + } + let parsed_packet = parsed_packet.unwrap(); + println!("parsed: {parsed_packet:?}"); + assert_eq!(parsed_packet.sequence_no(), 0); + let services = parsed_packet.services(); + assert_eq!(services.len(), 4); + // TODO: deterministic service ordering? + for service in services { + let codes = service.codes(); + assert!((1..=4).contains(&service.number())); + let no = service.number(); + // one of the services will have a length that is 1 byte shorter than others due to size + // limits of the packet. + assert!((30..=31).contains(&codes.len())); + assert_eq!(&codes[..30], &CODES[no as usize..no as usize + 30]); + } +}