gst-plugins-rs/net/rtp/src/av1/depay/imp.rs

583 lines
19 KiB
Rust

//
// Copyright (C) 2022 Vivienne Watermeier <vwatermeier@igalia.com>
//
// This Source Code Form is subject to the terms of the Mozilla Public License, v2.0.
// If a copy of the MPL was not distributed with this file, You can obtain one at
// <https://mozilla.org/MPL/2.0/>.
//
// SPDX-License-Identifier: MPL-2.0
use gst::{glib, subclass::prelude::*};
use gst_rtp::prelude::*;
use gst_rtp::subclass::prelude::*;
use std::{
cmp::Ordering,
io::{Cursor, Read, Seek, SeekFrom},
sync::Mutex,
};
use bitstream_io::{BitReader, BitWriter};
use once_cell::sync::Lazy;
use crate::av1::common::{
err_opt, leb128_size, parse_leb128, write_leb128, AggregationHeader, ObuType, SizedObu,
UnsizedObu, CLOCK_RATE, ENDIANNESS,
};
// TODO: handle internal size fields in RTP OBUs
/// Mutable depayloading state.
///
/// It is replaced by its `Default` value whenever the element is reset
/// (state changes, EOS / flush-stop events, and buffer discontinuities).
#[derive(Debug, Default)]
struct State {
/// used to store outgoing OBUs until the TU is complete
adapter: gst_base::UniqueAdapter,
/// RTP timestamp of the most recently processed packet, `None` before the first packet
last_timestamp: Option<u32>,
/// if true, the last packet of a temporal unit has been received
marked_packet: bool,
/// holds data for a fragment:
/// the parsed header of the OBU being reassembled and the bytes collected so far
obu_fragment: Option<(UnsizedObu, Vec<u8>)>,
}
/// Implementation struct of the RTP AV1 depayloader element.
#[derive(Debug, Default)]
pub struct RTPAv1Depay {
/// depayloading state, locked for the duration of each packet / event that touches it
state: Mutex<State>,
}
/// Debug category ("rtpav1depay") used by all log statements of this element,
/// created lazily on first use.
static CAT: Lazy<gst::DebugCategory> = Lazy::new(|| {
gst::DebugCategory::new(
"rtpav1depay",
gst::DebugColorFlags::empty(),
Some("RTP AV1 Depayloader"),
)
});
/// Two-byte temporal delimiter OBU in bitstream form: a header byte with the
/// `has_size_field` bit set, followed by a size field of 0.
/// It is inserted at the start of the data collected for every new temporal unit.
static TEMPORAL_DELIMITER: Lazy<gst::Memory> =
Lazy::new(|| gst::Memory::from_slice([0b0001_0010, 0]));
impl RTPAv1Depay {
    /// Discards everything collected so far (adapter contents, timestamp,
    /// marker flag and pending fragment) by putting a fresh `State` in place.
    fn reset(&self, state: &mut State) {
        gst::debug!(CAT, imp: self, "resetting state");

        // swap in the default state and release the old one right away
        drop(std::mem::take(state));
    }
}
// GObject type registration: the implementation struct is registered under the
// type name "GstRtpAv1Depay", wrapped by `super::RTPAv1Depay`, and derives
// from the RTP base depayloader class.
#[glib::object_subclass]
impl ObjectSubclass for RTPAv1Depay {
const NAME: &'static str = "GstRtpAv1Depay";
type Type = super::RTPAv1Depay;
type ParentType = gst_rtp::RTPBaseDepayload;
}
// No `ObjectImpl` methods are overridden; the trait defaults are used.
impl ObjectImpl for RTPAv1Depay {}
// No `GstObjectImpl` methods are overridden; the trait defaults are used.
impl GstObjectImpl for RTPAv1Depay {}
impl ElementImpl for RTPAv1Depay {
/// Returns the static element metadata (long name, classification,
/// description and author), built once on first access.
fn metadata() -> Option<&'static gst::subclass::ElementMetadata> {
    static ELEMENT_METADATA: Lazy<gst::subclass::ElementMetadata> = Lazy::new(|| {
        let long_name = "RTP AV1 Depayloader";
        let classification = "Codec/Depayloader/Network/RTP";
        let description = "Depayload AV1 from RTP packets";
        let author = "Vivienne Watermeier <vwatermeier@igalia.com>";

        gst::subclass::ElementMetadata::new(long_name, classification, description, author)
    });

    Some(Lazy::force(&ELEMENT_METADATA))
}
/// Returns the element's pad templates, built once on first access:
/// an always-present RTP sink pad and an always-present AV1 source pad.
fn pad_templates() -> &'static [gst::PadTemplate] {
    static PAD_TEMPLATES: Lazy<Vec<gst::PadTemplate>> = Lazy::new(|| {
        // input: AV1 video carried in RTP with a dynamic payload type
        let rtp_caps = gst::Caps::builder("application/x-rtp")
            .field("media", "video")
            .field("payload", gst::IntRange::new(96, 127))
            .field("clock-rate", CLOCK_RATE as i32)
            .field("encoding-name", "AV1")
            .build();
        let sink_pad_template = gst::PadTemplate::new(
            "sink",
            gst::PadDirection::Sink,
            gst::PadPresence::Always,
            &rtp_caps,
        )
        .unwrap();

        // output: parsed OBU stream aligned to temporal units
        let av1_caps = gst::Caps::builder("video/x-av1")
            .field("parsed", true)
            .field("stream-format", "obu-stream")
            .field("alignment", "tu")
            .build();
        let src_pad_template = gst::PadTemplate::new(
            "src",
            gst::PadDirection::Src,
            gst::PadPresence::Always,
            &av1_caps,
        )
        .unwrap();

        vec![src_pad_template, sink_pad_template]
    });

    PAD_TEMPLATES.as_slice()
}
/// Handles element state transitions.
///
/// The depayloading state is cleared before chaining up on READY -> PAUSED
/// and after chaining up on PAUSED -> READY; the parent's result is returned.
fn change_state(
    &self,
    transition: gst::StateChange,
) -> Result<gst::StateChangeSuccess, gst::StateChangeError> {
    gst::debug!(CAT, imp: self, "changing state: {}", transition);

    // takes the lock only for the duration of the reset
    let clear_state = || {
        let mut guard = self.state.lock().unwrap();
        self.reset(&mut guard);
    };

    if let gst::StateChange::ReadyToPaused = transition {
        clear_state();
    }

    let result = self.parent_change_state(transition);

    if let gst::StateChange::PausedToReady = transition {
        clear_state();
    }

    result
}
}
impl RTPBaseDepayloadImpl for RTPAv1Depay {
/// Called when sink caps are set; the incoming caps are not inspected.
/// A caps event carrying the source pad's template caps is pushed downstream.
fn set_caps(&self, _caps: &gst::Caps) -> Result<(), gst::LoggableError> {
    let src_pad = self.obj().src_pad();
    let template_caps = src_pad.pad_template_caps();

    let caps_event = gst::event::Caps::builder(&template_caps).build();
    src_pad.push_event(caps_event);

    Ok(())
}
/// Handles sink events: EOS and flush-stop clear the depayloading state,
/// then every event is forwarded to the parent class handler.
fn handle_event(&self, event: gst::Event) -> bool {
    let clears_state = matches!(
        event.view(),
        gst::EventView::Eos(_) | gst::EventView::FlushStop(_)
    );

    if clears_state {
        // the guard is released before chaining up
        let mut guard = self.state.lock().unwrap();
        self.reset(&mut guard);
    }

    self.parent_handle_event(event)
}
fn process_rtp_packet(
&self,
rtp: &gst_rtp::RTPBuffer<gst_rtp::rtp_buffer::Readable>,
) -> Option<gst::Buffer> {
gst::log!(
CAT,
imp: self,
"processing RTP packet with payload type {} and size {}",
rtp.payload_type(),
rtp.buffer().size(),
);
let payload = rtp.payload().map_err(err_opt!(self, payload_buf)).ok()?;
let mut state = self.state.lock().unwrap();
if rtp.buffer().flags().contains(gst::BufferFlags::DISCONT) {
gst::debug!(CAT, imp: self, "buffer discontinuity");
self.reset(&mut state);
}
// number of bytes that can be used in the next outgoing buffer
let mut bytes_ready = 0;
let mut reader = Cursor::new(payload);
let mut ready_obus = gst::Buffer::new();
let aggr_header = {
let mut byte = [0; 1];
reader
.read_exact(&mut byte)
.map_err(err_opt!(self, aggr_header_read))
.ok()?;
AggregationHeader::from(&byte)
};
// handle new temporal units
if state.marked_packet || state.last_timestamp != Some(rtp.timestamp()) {
if state.last_timestamp.is_some() && state.obu_fragment.is_some() {
gst::error!(
CAT,
imp: self,
concat!(
"invalid packet: packet is part of a new TU but ",
"the previous TU still has an incomplete OBU",
"marked_packet: {}, last_timestamp: {:?}"
),
state.marked_packet,
state.last_timestamp
);
self.reset(&mut state);
return None;
}
// all the currently stored bytes can be packed into the next outgoing buffer
bytes_ready = state.adapter.available();
// the next temporal unit starts with a temporal delimiter OBU
ready_obus
.get_mut()
.unwrap()
.insert_memory(None, TEMPORAL_DELIMITER.clone());
state.marked_packet = false;
}
state.marked_packet = rtp.is_marker();
state.last_timestamp = Some(rtp.timestamp());
// parse and prepare the received OBUs
let mut idx = 0;
// handle leading OBU fragment
if let Some((obu, ref mut bytes)) = &mut state.obu_fragment {
if !aggr_header.leading_fragment {
gst::error!(
CAT,
imp: self,
"invalid packet: ignores unclosed OBU fragment"
);
return None;
}
let (element_size, is_last_obu) =
self.find_element_info(rtp, &mut reader, &aggr_header, idx)?;
let bytes_end = bytes.len();
bytes.resize(bytes_end + element_size as usize, 0);
reader
.read_exact(&mut bytes[bytes_end..])
.map_err(err_opt!(self, buf_read))
.ok()?;
// if this OBU is complete, it can be appended to the adapter
if !(is_last_obu && aggr_header.trailing_fragment) {
let full_obu = {
let size = bytes.len() as u32 - obu.header_len;
let leb_size = leb128_size(size) as u32;
obu.as_sized(size, leb_size)
};
let buffer = self.translate_obu(&mut Cursor::new(bytes.as_slice()), &full_obu)?;
state.adapter.push(buffer);
state.obu_fragment = None;
}
}
// handle other OBUs, including trailing fragments
while reader.position() < rtp.payload_size() as u64 {
let (element_size, is_last_obu) =
self.find_element_info(rtp, &mut reader, &aggr_header, idx)?;
let header_pos = reader.position();
let mut bitreader = BitReader::endian(&mut reader, ENDIANNESS);
let obu = UnsizedObu::parse(&mut bitreader)
.map_err(err_opt!(self, obu_read))
.ok()?;
reader
.seek(SeekFrom::Start(header_pos))
.map_err(err_opt!(self, buf_read))
.ok()?;
// ignore these OBU types
if matches!(obu.obu_type, ObuType::TemporalDelimiter | ObuType::TileList) {
reader
.seek(SeekFrom::Current(element_size as i64))
.map_err(err_opt!(self, buf_read))
.ok()?;
}
// trailing OBU fragments are stored in the state
if is_last_obu && aggr_header.trailing_fragment {
let bytes_left = rtp.payload_size() - (reader.position() as u32);
let mut bytes = vec![0; bytes_left as usize];
reader
.read_exact(bytes.as_mut_slice())
.map_err(err_opt!(self, buf_read))
.ok()?;
state.obu_fragment = Some((obu, bytes));
}
// full OBUs elements are translated and appended to the adapter
else {
let full_obu = {
let size = element_size - obu.header_len;
let leb_size = leb128_size(size) as u32;
obu.as_sized(size, leb_size)
};
ready_obus.append(self.translate_obu(&mut reader, &full_obu)?);
}
idx += 1;
}
state.adapter.push(ready_obus);
if state.marked_packet {
if state.obu_fragment.is_some() {
gst::error!(
CAT,
imp: self,
concat!(
"invalid packet: has marker bit set, but ",
"last OBU is not yet complete"
)
);
self.reset(&mut state);
return None;
}
bytes_ready = state.adapter.available();
}
// now push all the complete temporal units
if bytes_ready > 0 {
gst::log!(
CAT,
imp: self,
"creating buffer containing {} bytes of data...",
bytes_ready
);
Some(
state
.adapter
.take_buffer(bytes_ready)
.map_err(err_opt!(self, buf_take))
.ok()?,
)
} else {
None
}
}
}
impl RTPAv1Depay {
/// Find out the next OBU element's size, and if it is the last OBU in the packet.
///
/// The reader is expected to be at the first byte of the element,
/// or its preceding size field if present,
/// and will be at the first byte past the element's size field afterwards.
///
/// `index` is the zero-based position of the element within the packet; it is
/// only consulted when the aggregation header carries an OBU count. In that
/// case the last element has no size field and spans the rest of the payload.
///
/// Returns `None` if the size field cannot be read, or if the element would
/// extend past the end of the RTP payload.
fn find_element_info(
    &self,
    rtp: &gst_rtp::RTPBuffer<gst_rtp::rtp_buffer::Readable>,
    reader: &mut Cursor<&[u8]>,
    aggr_header: &AggregationHeader,
    index: u32,
) -> Option<(u32, bool)> {
    // offsets are compared as u64 so an oversized size field cannot wrap around
    let payload_size = u64::from(rtp.payload_size());

    // `Some(is_last)` if the packet announces its OBU count, `None` otherwise
    let counted_last = aggr_header
        .obu_count
        .map(|count| index + 1 == count as u32);

    if counted_last == Some(true) {
        // no size field: the element is whatever is left of the payload
        return match payload_size.checked_sub(reader.position()) {
            // cannot truncate, the remainder is at most the (u32) payload size
            Some(remaining) => Some((remaining as u32, true)),
            None => {
                gst::error!(
                    CAT,
                    imp: self,
                    "invalid packet: reader is positioned past the end of the payload"
                );
                None
            }
        };
    }

    let element_size: u32 = parse_leb128(&mut BitReader::endian(&mut *reader, ENDIANNESS))
        .map_err(err_opt!(self, leb_read))
        .ok()?;
    let element_end = reader.position().saturating_add(u64::from(element_size));

    let is_last_obu = match payload_size.cmp(&element_end) {
        Ordering::Greater => false,
        // without an OBU count, the element ending the payload is the last one;
        // with a count, this element is known not to be the last
        Ordering::Equal => counted_last.is_none(),
        Ordering::Less => {
            gst::error!(
                CAT,
                imp: self,
                "invalid packet: size field gives impossibly large OBU size"
            );
            return None;
        }
    };

    Some((element_size, is_last_obu))
}
/// Builds a buffer holding one OBU in AV1 bitstream format from its RTP form.
///
/// The reader must be positioned at the OBU header. The output consists of
/// the header with `has_size_field` forced on, a freshly written LEB128 size
/// field, and the OBU payload. A size field already present in the input is
/// read and discarded.
fn translate_obu(&self, reader: &mut Cursor<&[u8]>, obu: &SizedObu) -> Option<gst::Buffer> {
    let total_len = obu.full_size() as usize;
    let header_end = obu.header_len as usize;

    let allocated = gst::Buffer::with_size(total_len)
        .map_err(err_opt!(self, buf_alloc))
        .ok()?;
    let mut out = allocated.into_mapped_buffer_writable().unwrap();

    // copy the OBU header over verbatim...
    let header_read = reader.read_exact(&mut out[..header_end]);
    header_read.map_err(err_opt!(self, buf_read)).ok()?;

    // ...and flag that a size field follows it
    out[0] |= 0b0000_0010;

    // the incoming size field, if any, is replaced rather than copied
    if obu.has_size_field {
        let mut bits = BitReader::endian(&mut *reader, ENDIANNESS);
        parse_leb128(&mut bits)
            .map_err(err_opt!(self, leb_read))
            .ok()?;
    }

    // emit the size field directly behind the header
    let mut size_writer = BitWriter::endian(Cursor::new(&mut out[header_end..]), ENDIANNESS);
    write_leb128(&mut size_writer, obu.size)
        .map_err(err_opt!(self, leb_write))
        .ok()?;

    // everything after header and size field is the OBU payload
    let payload_read = reader.read_exact(&mut out[(obu.header_len + obu.leb_size) as usize..]);
    payload_read.map_err(err_opt!(self, buf_read)).ok()?;

    Some(out.into_buffer())
}
}
#[cfg(test)]
#[rustfmt::skip]
mod tests {
use super::*;
use std::io::Cursor;
// Checks that `translate_obu` turns RTP-form OBU bytes into bitstream-form bytes.
// Each entry is (OBU description, RTP-side input bytes, expected output bytes).
#[test]
fn test_translate_obu() {
gst::init().unwrap();
let test_data = [
// payload-less temporal delimiter: size bit gets set and a zero size byte is added
(
SizedObu {
obu_type: ObuType::TemporalDelimiter,
has_extension: false,
has_size_field: false,
temporal_id: 0,
spatial_id: 0,
size: 0,
leb_size: 1,
header_len: 1,
is_fragment: false,
},
vec![0b0001_0000],
vec![0b0001_0010, 0],
), (
// frame with extension header and no size field in the input
SizedObu {
obu_type: ObuType::Frame,
has_extension: true,
has_size_field: false,
temporal_id: 3,
spatial_id: 2,
size: 5,
leb_size: 1,
header_len: 2,
is_fragment: false,
},
vec![0b0011_0100, 0b0111_0000, 1, 2, 3, 4, 5],
vec![0b0011_0110, 0b0111_0000, 0b0000_0101, 1, 2, 3, 4, 5],
), (
// same frame, but the input already carries a size field after the header
SizedObu {
obu_type: ObuType::Frame,
has_extension: true,
has_size_field: true,
temporal_id: 3,
spatial_id: 2,
size: 5,
leb_size: 1,
header_len: 2,
is_fragment: false,
},
vec![0b0011_0100, 0b0111_0000, 0b0000_0101, 1, 2, 3, 4, 5],
vec![0b0011_0110, 0b0111_0000, 0b0000_0101, 1, 2, 3, 4, 5],
)
];
let element = <RTPAv1Depay as ObjectSubclass>::Type::new();
for (idx, (obu, rtp_bytes, out_bytes)) in test_data.into_iter().enumerate() {
println!("running test {idx}...");
let mut reader = Cursor::new(rtp_bytes.as_slice());
let actual = element.imp().translate_obu(&mut reader, &obu);
// the whole input must have been consumed
assert_eq!(reader.position(), rtp_bytes.len() as u64);
assert!(actual.is_some());
let actual = actual
.unwrap()
.into_mapped_buffer_readable()
.unwrap();
assert_eq!(actual.as_slice(), out_bytes.as_slice());
}
}
#[test]
#[allow(clippy::type_complexity)]
fn test_find_element_info() {
gst::init().unwrap();
let test_data: [(Vec<(u32, bool)>, u32, Vec<u8>, AggregationHeader); 4] = [
(
vec![(1, false)], // expected results
100, // RTP payload size
vec![0b0000_0001, 0b0001_0000],
AggregationHeader { obu_count: None, ..AggregationHeader::default() },
), (
vec![(5, true)],
5,
vec![0b0111_1000, 0, 0, 0, 0],
AggregationHeader { obu_count: Some(1), ..AggregationHeader::default() },
), (
vec![(7, true)],
8,
vec![0b0000_0111, 0b0011_0110, 0b0010_1000, 0b0000_1010, 1, 2, 3, 4],
AggregationHeader { obu_count: None, ..AggregationHeader::default() },
), (
vec![(6, false), (4, true)],
11,
vec![0b0000_0110, 0b0111_1000, 1, 2, 3, 4, 5, 0b0011_0000, 1, 2, 3],
AggregationHeader { obu_count: Some(2), ..AggregationHeader::default() },
)
];
let element = <RTPAv1Depay as ObjectSubclass>::Type::new();
for (idx, (
info,
payload_size,
rtp_bytes,
aggr_header,
)) in test_data.into_iter().enumerate() {
println!("running test {idx}...");
let buffer = gst::Buffer::new_rtp_with_sizes(payload_size, 0, 0).unwrap();
let rtp = gst_rtp::RTPBuffer::from_buffer_readable(&buffer).unwrap();
let mut reader = Cursor::new(rtp_bytes.as_slice());
let mut element_size = 0;
for (obu_idx, expected) in info.into_iter().enumerate() {
if element_size != 0 {
reader.seek(SeekFrom::Current(element_size as i64)).unwrap();
}
println!("testing element {} with reader position {}...", obu_idx, reader.position());
let actual = element.imp().find_element_info(&rtp, &mut reader, &aggr_header, obu_idx as u32);
assert_eq!(actual, Some(expected));
element_size = actual.unwrap().0;
}
}
}
}