plugin: Implement bandwidth estimator based on the Google Congestion Control algorithm

See https://datatracker.ietf.org/doc/html/draft-ietf-rmcat-gcc-02

This commit implements the bandwidth estimation as a GStreamer element
that is then used in webrtcbin through the new `request-bandwidth-estimator`
signal.

This keeps our Homegrown congestion controller but removes the possibility
to switch CC algorithm at runtime.
This commit is contained in:
Thibault Saunier 2022-05-11 16:58:53 -04:00 committed by Mathieu Duponchelle
parent 287e76847a
commit 64f664c859
7 changed files with 1554 additions and 100 deletions

1
Cargo.lock generated
View file

@ -1856,6 +1856,7 @@ dependencies = [
"async-native-tls 0.4.0",
"async-std",
"async-tungstenite",
"chrono",
"clap",
"fastrand",
"futures",

View file

@ -17,6 +17,7 @@ gst-sdp = { git="https://gitlab.freedesktop.org/gstreamer/gstreamer-rs", package
gst-rtp = { git="https://gitlab.freedesktop.org/gstreamer/gstreamer-rs", package = "gstreamer-rtp", features = ["v1_20"] }
gst-utils = { git="https://gitlab.freedesktop.org/gstreamer/gstreamer-rs", package = "gstreamer-utils" }
once_cell = "1.0"
chrono = { version = "0.4", default-features = false }
smallvec = "1"
anyhow = "1"
thiserror = "1"

1366
plugins/src/gcc/imp.rs Normal file

File diff suppressed because it is too large Load diff

16
plugins/src/gcc/mod.rs Normal file
View file

@ -0,0 +1,16 @@
use gst::glib;
use gst::prelude::*;
mod imp;
glib::wrapper! {
pub struct BandwidthEstimator(ObjectSubclass<imp::BandwidthEstimator>) @extends gst::Element, gst::Object;
}
pub fn register(plugin: &gst::Plugin) -> Result<(), glib::BoolError> {
gst::Element::register(
Some(plugin),
"rtpgccbwe",
gst::Rank::None,
BandwidthEstimator::static_type(),
)
}

View file

@ -1,10 +1,12 @@
use gst::glib;
pub mod gcc;
mod signaller;
pub mod webrtcsink;
fn plugin_init(plugin: &gst::Plugin) -> Result<(), glib::BoolError> {
webrtcsink::register(plugin)?;
gcc::register(plugin)?;
Ok(())
}

View file

@ -47,19 +47,23 @@ const DEFAULT_CONGESTION_CONTROL: WebRTCSinkCongestionControl =
const DEFAULT_DO_FEC: bool = true;
const DEFAULT_DO_RETRANSMISSION: bool = true;
const DEFAULT_ENABLE_DATA_CHANNEL_NAVIGATION: bool = false;
const DEFAULT_START_BITRATE: u32 = 2048000;
#[derive(Debug, Clone, Copy)]
struct CCInfo {
heuristic: WebRTCSinkCongestionControl,
min_bitrate: u32,
max_bitrate: u32,
start_bitrate: u32,
}
/// User configuration
struct Settings {
video_caps: gst::Caps,
audio_caps: gst::Caps,
turn_server: Option<String>,
stun_server: Option<String>,
cc_heuristic: WebRTCSinkCongestionControl,
min_bitrate: u32,
max_bitrate: u32,
start_bitrate: u32,
cc_info: CCInfo,
do_fec: bool,
do_retransmission: bool,
enable_data_channel_navigation: bool,
@ -114,7 +118,7 @@ struct WebRTCPad {
/// name in order to provide a unified set / get bitrate API, also
/// tracks a raw capsfilter used to resize / decimate the input video
/// stream according to the bitrate, thresholds hardcoded for now
struct VideoEncoder {
pub struct VideoEncoder {
factory_name: String,
codec_name: String,
element: gst::Element,
@ -186,19 +190,18 @@ enum ControllerType {
// Running the "loss based controller"
Loss,
}
struct Consumer {
pipeline: gst::Pipeline,
webrtcbin: gst::Element,
webrtc_pads: HashMap<u32, WebRTCPad>,
peer_id: String,
encoders: Vec<VideoEncoder>,
/// None if congestion control was disabled
// Our Homegrown controller
congestion_controller: Option<CongestionController>,
sdp: Option<gst_sdp::SDPMessage>,
stats: gst::Structure,
max_bitrate: u32,
cc_info: CCInfo,
links: HashMap<u32, gst_utils::ConsumptionLink>,
stats_sigid: Option<glib::SignalHandlerId>,
@ -308,12 +311,14 @@ impl Default for Settings {
.iter()
.map(|s| gst::Structure::new_empty(s))
.collect::<gst::Caps>(),
cc_heuristic: WebRTCSinkCongestionControl::Homegrown,
stun_server: DEFAULT_STUN_SERVER.map(String::from),
turn_server: None,
min_bitrate: DEFAULT_MIN_BITRATE,
max_bitrate: DEFAULT_MAX_BITRATE,
start_bitrate: DEFAULT_START_BITRATE,
cc_info: CCInfo {
heuristic: WebRTCSinkCongestionControl::Homegrown,
min_bitrate: DEFAULT_MIN_BITRATE,
max_bitrate: (DEFAULT_MAX_BITRATE as f64 * 1.5) as u32,
start_bitrate: DEFAULT_START_BITRATE,
},
do_fec: DEFAULT_DO_FEC,
do_retransmission: DEFAULT_DO_RETRANSMISSION,
enable_data_channel_navigation: DEFAULT_ENABLE_DATA_CHANNEL_NAVIGATION,
@ -577,13 +582,15 @@ fn setup_encoding(
Ok((enc, conv_filter, pay))
}
fn lookup_transport_stats(stats: &gst::StructureRef) -> Option<gst::Structure> {
fn lookup_twcc_stats(stats: &gst::StructureRef) -> Option<gst::Structure> {
for (_, field_value) in stats {
if let Ok(s) = field_value.get::<gst::Structure>() {
if let Ok(type_) = s.get::<gst_webrtc::WebRTCStatsType>("type") {
if type_ == gst_webrtc::WebRTCStatsType::Transport && s.has_field("gst-twcc-stats")
if (type_ == gst_webrtc::WebRTCStatsType::Transport
|| type_ == gst_webrtc::WebRTCStatsType::CandidatePair)
&& s.has_field("gst-twcc-stats")
{
return Some(s);
return Some(s.get::<gst::Structure>("gst-twcc-stats").unwrap());
}
}
}
@ -616,7 +623,7 @@ impl VideoEncoder {
}
}
fn bitrate(&self) -> i32 {
pub fn bitrate(&self) -> i32 {
match self.factory_name.as_str() {
"vp8enc" | "vp9enc" => self.element.property::<i32>("target-bitrate"),
"x264enc" | "nvh264enc" | "vaapih264enc" | "vaapivp8enc" => {
@ -626,7 +633,7 @@ impl VideoEncoder {
}
}
fn scale_height_round_2(&self, height: i32) -> i32 {
pub fn scale_height_round_2(&self, height: i32) -> i32 {
let ratio = gst_video::calculate_display_ratio(
self.video_info.width(),
self.video_info.height(),
@ -640,7 +647,7 @@ impl VideoEncoder {
(width + 1) & !1
}
fn set_bitrate(&mut self, element: &super::WebRTCSink, bitrate: i32) {
pub fn set_bitrate(&mut self, element: &super::WebRTCSink, bitrate: i32) {
match self.factory_name.as_str() {
"vp8enc" | "vp9enc" => self.element.set_property("target-bitrate", bitrate),
"x264enc" | "nvh264enc" | "vaapih264enc" | "vaapivp8enc" => self
@ -951,9 +958,7 @@ impl CongestionController {
stats: &gst::StructureRef,
encoders: &mut Vec<VideoEncoder>,
) {
if let Some(twcc_stats) = lookup_transport_stats(stats).and_then(|transport_stats| {
transport_stats.get::<gst::Structure>("gst-twcc-stats").ok()
}) {
if let Some(twcc_stats) = lookup_twcc_stats(stats) {
let op = self.update_delay(element, &twcc_stats, self.lookup_rtt(stats));
self.apply_control_op(element, encoders, op, ControllerType::Delay);
}
@ -1024,10 +1029,14 @@ impl CongestionController {
if target_bitrate != prev_bitrate {
gst::info!(
CAT,
"{:?} {} => {}",
"{:?} {} => {} | on delay {} - on loss {} | min {} - max {}",
control_op,
human_bytes::human_bytes(prev_bitrate),
human_bytes::human_bytes(target_bitrate)
human_bytes::human_bytes(target_bitrate),
human_bytes::human_bytes(self.target_bitrate_on_delay),
human_bytes::human_bytes(self.target_bitrate_on_loss),
human_bytes::human_bytes(self.min_bitrate),
human_bytes::human_bytes(self.max_bitrate),
);
}
@ -1123,16 +1132,16 @@ impl Consumer {
webrtcbin: gst::Element,
peer_id: String,
congestion_controller: Option<CongestionController>,
max_bitrate: u32,
cc_info: CCInfo,
) -> Self {
Self {
pipeline,
webrtcbin,
peer_id,
cc_info,
congestion_controller,
max_bitrate,
sdp: None,
stats: gst::Structure::new_empty("application/x-webrtc-stats"),
sdp: None,
webrtc_pads: HashMap::new(),
encoders: Vec::new(),
links: HashMap::new(),
@ -1326,16 +1335,27 @@ impl Consumer {
transceiver,
);
if let Some(congestion_controller) = self.congestion_controller.as_mut() {
congestion_controller.target_bitrate_on_delay += enc.bitrate();
congestion_controller.target_bitrate_on_loss =
congestion_controller.target_bitrate_on_delay;
enc.transceiver.set_property("fec-percentage", 0u32);
} else {
/* If congestion control is disabled, we simply use the highest
* known "safe" value for the bitrate. */
enc.set_bitrate(element, self.max_bitrate as i32);
enc.transceiver.set_property("fec-percentage", 50u32);
match self.cc_info.heuristic {
WebRTCSinkCongestionControl::Disabled => {
// If congestion control is disabled, we simply use the highest
// known "safe" value for the bitrate.
enc.set_bitrate(element, (self.cc_info.max_bitrate as f64 / 1.5) as i32);
enc.transceiver.set_property("fec-percentage", 50u32);
}
WebRTCSinkCongestionControl::Homegrown => {
if let Some(congestion_controller) = self.congestion_controller.as_mut() {
congestion_controller.target_bitrate_on_delay += enc.bitrate();
congestion_controller.target_bitrate_on_loss =
congestion_controller.target_bitrate_on_delay;
enc.transceiver.set_property("fec-percentage", 0u32);
} else {
/* If congestion control is disabled, we simply use the highest
* known "safe" value for the bitrate. */
enc.set_bitrate(element, self.cc_info.max_bitrate as i32);
enc.transceiver.set_property("fec-percentage", 50u32);
}
}
_ => enc.transceiver.set_property("fec-percentage", 0u32),
}
self.encoders.push(enc);
@ -1702,21 +1722,21 @@ impl WebRTCSink {
) -> Result<(), WebRTCSinkError> {
let settings = self.settings.lock().unwrap();
let mut state = self.state.lock().unwrap();
let peer_id = peer_id.to_string();
if state.consumers.contains_key(peer_id) {
return Err(WebRTCSinkError::DuplicateConsumerId(peer_id.to_string()));
if state.consumers.contains_key(&peer_id) {
return Err(WebRTCSinkError::DuplicateConsumerId(peer_id));
}
gst::info!(CAT, obj: element, "Adding consumer {}", peer_id);
let pipeline = gst::Pipeline::new(Some(&format!("consumer-pipeline-{}", peer_id)));
let webrtcbin = make_element("webrtcbin", None).map_err(|err| {
WebRTCSinkError::ConsumerPipelineError {
peer_id: peer_id.to_string(),
let webrtcbin = make_element("webrtcbin", Some(&format!("webrtcbin-{}", peer_id)))
.map_err(|err| WebRTCSinkError::ConsumerPipelineError {
peer_id: peer_id.clone(),
details: err.to_string(),
}
})?;
})?;
webrtcbin.set_property_from_str("bundle-policy", "max-bundle");
@ -1728,10 +1748,54 @@ impl WebRTCSink {
webrtcbin.set_property("turn-server", turn_server);
}
match settings.cc_info.heuristic {
WebRTCSinkCongestionControl::GoogleCongestionControl => {
let cc_info = settings.cc_info;
webrtcbin.connect_closure(
"request-aux-sender",
false,
glib::closure!(@watch element, @strong peer_id
=> move |_webrtcbin: gst::Element, _transport: gst::Object| {
let cc = match gst::ElementFactory::make("rtpgccbwe", None) {
Err(err) => {
glib::g_warning!("webrtcsink",
"The `rtpgccbwe` element is not available \
not doing any congestion control: {err:?}"
);
return None;
},
Ok(e) => {
e.set_properties(&[
("min-bitrate", &cc_info.min_bitrate),
("estimated-bitrate", &cc_info.start_bitrate),
("max-bitrate", &cc_info.max_bitrate),
]);
// TODO: Bind properties with @element's
e
}
};
cc.connect_notify(Some("estimated-bitrate"),
glib::clone!(@weak element, @strong peer_id
=> move |bwe, pspec| {
element.imp().set_bitrate(&element, &peer_id,
bwe.property::<u32>(pspec.name()));
}
));
Some(cc)
}),
);
}
}
pipeline.add(&webrtcbin).unwrap();
let element_clone = element.downgrade();
let peer_id_clone = peer_id.to_owned();
let peer_id_clone = peer_id.clone();
webrtcbin.connect("on-ice-candidate", false, move |values| {
if let Some(element) = element_clone.upgrade() {
let this = Self::from_instance(&element);
@ -1748,7 +1812,7 @@ impl WebRTCSink {
});
let element_clone = element.downgrade();
let peer_id_clone = peer_id.to_owned();
let peer_id_clone = peer_id.clone();
webrtcbin.connect_notify(Some("connection-state"), move |webrtcbin, _pspec| {
if let Some(element) = element_clone.upgrade() {
let state =
@ -1779,7 +1843,7 @@ impl WebRTCSink {
});
let element_clone = element.downgrade();
let peer_id_clone = peer_id.to_owned();
let peer_id_clone = peer_id.clone();
webrtcbin.connect_notify(Some("ice-connection-state"), move |webrtcbin, _pspec| {
if let Some(element) = element_clone.upgrade() {
let state = webrtcbin
@ -1826,7 +1890,7 @@ impl WebRTCSink {
});
let element_clone = element.downgrade();
let peer_id_clone = peer_id.to_owned();
let peer_id_clone = peer_id.clone();
webrtcbin.connect_notify(Some("ice-gathering-state"), move |webrtcbin, _pspec| {
let state =
webrtcbin.property::<gst_webrtc::WebRTCICEGatheringState>("ice-gathering-state");
@ -1845,16 +1909,16 @@ impl WebRTCSink {
let mut consumer = Consumer::new(
pipeline.clone(),
webrtcbin.clone(),
peer_id.to_string(),
match settings.cc_heuristic {
WebRTCSinkCongestionControl::Disabled => None,
peer_id.clone(),
match settings.cc_info.heuristic {
WebRTCSinkCongestionControl::Homegrown => Some(CongestionController::new(
peer_id,
settings.min_bitrate,
settings.max_bitrate,
&peer_id,
settings.cc_info.min_bitrate,
settings.cc_info.max_bitrate,
)),
_ => None,
},
settings.max_bitrate,
settings.cc_info,
);
let rtpbin = webrtcbin
@ -1983,7 +2047,7 @@ impl WebRTCSink {
//
// This is completely safe, as we know that by now all conditions are gathered:
// webrtcbin is in the Ready state, and all its transceivers have codec_preferences.
self.negotiate(element, peer_id);
self.negotiate(element, &peer_id);
pipeline.set_state(gst::State::Playing).map_err(|err| {
WebRTCSinkError::ConsumerPipelineError {
@ -2051,6 +2115,32 @@ impl WebRTCSink {
webrtcbin.emit_by_name::<()>("get-stats", &[&None::<gst::Pad>, &promise]);
}
fn set_bitrate(&self, element: &super::WebRTCSink, peer_id: &str, bitrate: u32) {
let mut state = element.imp().state.lock().unwrap();
if let Some(consumer) = state.consumers.get_mut(peer_id) {
let fec_ratio = {
// Start adding some FEC when the bitrate > 2Mbps as we found experimentally
// that it is not worth it below that threshold
if bitrate <= 2_000_000 || consumer.cc_info.max_bitrate <= 2_000_000 {
0f64
} else {
(bitrate as f64 - 2_000_000.)
/ (consumer.cc_info.max_bitrate as f64 - 2_000_000.)
}
};
let fec_percentage = fec_ratio * 50f64;
let encoders_bitrate = ((bitrate as f64) / (1. + (fec_percentage / 100.))) as i32;
for encoder in consumer.encoders.iter_mut() {
encoder.set_bitrate(element, encoders_bitrate);
encoder
.transceiver
.set_property("fec-percentage", fec_percentage as u32);
}
}
}
fn on_remote_description_set(&self, element: &super::WebRTCSink, peer_id: String) {
let mut state = self.state.lock().unwrap();
let mut remove = false;
@ -2105,21 +2195,16 @@ impl WebRTCSink {
let element_clone = element.downgrade();
let webrtcbin = consumer.webrtcbin.downgrade();
let peer_id_clone = peer_id.clone();
task::spawn(async move {
let mut interval =
async_std::stream::interval(std::time::Duration::from_millis(100));
while interval.next().await.is_some() {
let element_clone = element_clone.clone();
let peer_id_clone = peer_id_clone.clone();
if let (Some(webrtcbin), Some(element)) =
(webrtcbin.upgrade(), element_clone.upgrade())
{
element
.imp()
.process_stats(&element, webrtcbin, &peer_id_clone);
element.imp().process_stats(&element, webrtcbin, &peer_id);
} else {
break;
}
@ -2547,7 +2632,7 @@ impl ObjectImpl for WebRTCSink {
"Defines how congestion is controlled, if at all",
WebRTCSinkCongestionControl::static_type(),
DEFAULT_CONGESTION_CONTROL as i32,
glib::ParamFlags::READWRITE | gst::PARAM_FLAG_MUTABLE_PLAYING,
glib::ParamFlags::READWRITE | gst::PARAM_FLAG_MUTABLE_READY,
),
glib::ParamSpecUInt::new(
"min-bitrate",
@ -2653,48 +2738,27 @@ impl ObjectImpl for WebRTCSink {
}
"congestion-control" => {
let mut settings = self.settings.lock().unwrap();
let new_heuristic = value
settings.cc_info.heuristic = value
.get::<WebRTCSinkCongestionControl>()
.expect("type checked upstream");
if new_heuristic != settings.cc_heuristic {
settings.cc_heuristic = new_heuristic;
let mut state = self.state.lock().unwrap();
for (peer_id, consumer) in state.consumers.iter_mut() {
match new_heuristic {
WebRTCSinkCongestionControl::Disabled => {
consumer.congestion_controller.take();
for encoder in &mut consumer.encoders {
encoder
.set_bitrate(&self.instance(), consumer.max_bitrate as i32);
encoder.transceiver.set_property("fec-percentage", 50u32);
}
consumer.stats_sigid.take();
}
WebRTCSinkCongestionControl::Homegrown => {
let _ = consumer.congestion_controller.insert(
CongestionController::new(
peer_id,
settings.min_bitrate,
settings.max_bitrate,
),
);
}
}
}
}
}
"min-bitrate" => {
let mut settings = self.settings.lock().unwrap();
settings.min_bitrate = value.get::<u32>().expect("type checked upstream");
settings.cc_info.min_bitrate = value.get::<u32>().expect("type checked upstream");
}
"max-bitrate" => {
let mut settings = self.settings.lock().unwrap();
settings.max_bitrate = value.get::<u32>().expect("type checked upstream");
settings.cc_info.max_bitrate = (value.get::<u32>().expect("type checked upstream")
as f32
* if settings.do_fec {
settings.cc_info.max_bitrate as f32 * 1.5
} else {
1.
}) as u32;
}
"start-bitrate" => {
let mut settings = self.settings.lock().unwrap();
settings.start_bitrate = value.get::<u32>().expect("type checked upstream");
settings.cc_info.start_bitrate = value.get::<u32>().expect("type checked upstream");
}
"do-fec" => {
let mut settings = self.settings.lock().unwrap();
@ -2731,7 +2795,7 @@ impl ObjectImpl for WebRTCSink {
}
"congestion-control" => {
let settings = self.settings.lock().unwrap();
settings.cc_heuristic.to_value()
settings.cc_info.heuristic.to_value()
}
"stun-server" => {
let settings = self.settings.lock().unwrap();
@ -2743,15 +2807,17 @@ impl ObjectImpl for WebRTCSink {
}
"min-bitrate" => {
let settings = self.settings.lock().unwrap();
settings.min_bitrate.to_value()
settings.cc_info.min_bitrate.to_value()
}
"max-bitrate" => {
let settings = self.settings.lock().unwrap();
settings.max_bitrate.to_value()
((settings.cc_info.max_bitrate as f32 / if settings.do_fec { 1.5 } else { 1. })
as u32)
.to_value()
}
"start-bitrate" => {
let settings = self.settings.lock().unwrap();
settings.start_bitrate.to_value()
settings.cc_info.start_bitrate.to_value()
}
"do-fec" => {
let settings = self.settings.lock().unwrap();
@ -2873,7 +2939,7 @@ impl ObjectImpl for WebRTCSink {
let this = element.imp();
let settings = this.settings.lock().unwrap();
configure_encoder(&enc, settings.start_bitrate);
configure_encoder(&enc, settings.cc_info.start_bitrate);
// Return false here so that latter handlers get called
Some(false.to_value())

View file

@ -1,6 +1,6 @@
use gst::glib;
use gst::prelude::*;
use gst::subclass::prelude::ObjectSubclassExt;
use gst::subclass::prelude::*;
use std::error::Error;
mod imp;
@ -135,6 +135,8 @@ pub enum WebRTCSinkCongestionControl {
Disabled,
#[enum_value(name = "Homegrown: simple sender-side heuristic", nick = "homegrown")]
Homegrown,
#[enum_value(name = "Google Congestion Control algorithm", nick = "gcc")]
GoogleCongestionControl,
}
#[glib::flags(name = "GstWebRTCSinkMitigationMode")]