Make our "loss based control" algorithm closer to what is defined in [GCC]

As specified in Google Congestion Control we should run the packet loss
estimation algorithm "every time feedback from the receiver is
received".

And, also as defined by GCC, we now keep 2 separate estimated bitrates,
one from the delay-based controller and one from the loss-based one,
and we use the minimum of the two as our current estimate.

[GCC]: https://datatracker.ietf.org/doc/html/draft-ietf-rmcat-gcc-02
This commit is contained in:
Thibault Saunier 2022-03-30 15:39:19 +00:00 committed by Mathieu Duponchelle
parent 3c81afa7b2
commit 075a625305
3 changed files with 300 additions and 150 deletions

7
Cargo.lock generated
View file

@ -903,6 +903,12 @@ version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9100414882e15fb7feccb4897e5f0ff0ff1ca7d1a86a23208ada4d7a18e6c6c4"
[[package]]
name = "human_bytes"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "88a0d4dc39ec942e44c1c306aa196da67f2bd6a30dc7b4a475465c13ccf28817"
[[package]]
name = "idna"
version = "0.2.3"
@ -1847,6 +1853,7 @@ dependencies = [
"gstreamer-sdp",
"gstreamer-video",
"gstreamer-webrtc",
"human_bytes",
"once_cell",
"serde",
"serde_json",

View file

@ -27,6 +27,7 @@ serde = "1"
serde_json = "1"
fastrand = "1.0"
webrtcsink-protocol = { version = "0.1", path="../protocol" }
human_bytes = "0.3.1"
[dev-dependencies]
tracing = { version = "0.1", features = ["log"] }

View file

@ -126,14 +126,21 @@ struct VideoEncoder {
}
struct CongestionController {
/// Overall bitrate target for all video streams.
/// Note: The target bitrate applied is the min of
/// target_bitrate_on_delay and target_bitrate_on_loss
///
/// Bitrate target based on delay factor for all video streams.
/// Hasn't been tested with multiple video streams, but
/// current design is simply to divide bitrate equally.
bitrate_ema: Option<f64>,
target_bitrate_on_delay: i32,
/// Bitrate target based on loss for all video streams.
target_bitrate_on_loss: i32,
/// Exponential moving average, updated when bitrate is
/// decreased, discarded when increased again past last
/// congestion window. Smoothing factor hardcoded.
target_bitrate: i32,
bitrate_ema: Option<f64>,
/// Exponentially weighted moving variance, recursively
/// updated along with bitrate_ema. sqrt'd to obtain standard
/// deviation, used to determine whether to increase bitrate
@ -162,11 +169,23 @@ enum CongestionControlOp {
/// Don't update target bitrate
Hold,
/// Decrease target bitrate
Decrease(f64),
Decrease {
factor: f64,
#[allow(dead_code)]
reason: String, // for Debug
},
/// Increase target bitrate, either additively or multiplicatively
Increase(IncreaseType),
}
/// Identifies which of the two congestion controllers an operation comes from.
///
/// `clamp_bitrate` matches on this value to decide which target bitrate
/// field of the `CongestionController` receives the clamped value.
#[derive(Debug, Clone, Copy)]
enum ControllerType {
/// The "delay-based controller"; selects `target_bitrate_on_delay`.
Delay,
/// The "loss-based controller"; selects `target_bitrate_on_loss`.
Loss,
}
struct Consumer {
pipeline: gst::Pipeline,
webrtcbin: gst::Element,
@ -179,6 +198,8 @@ struct Consumer {
stats: gst::Structure,
max_bitrate: u32,
stats_sigid: Option<glib::SignalHandlerId>,
}
#[derive(PartialEq)]
@ -215,6 +236,7 @@ fn create_navigation_event<N: IsA<gst_video::Navigation>>(sink: &N, msg: &str) {
gst_error!(CAT, "Invalid navigation event: {:?}", msg);
}
}
/// Simple utility for tearing down a pipeline cleanly
struct PipelineWrapper(gst::Pipeline);
@ -465,20 +487,6 @@ fn setup_encoding(
Ok((enc, conv_filter, pay))
}
/// Returns the first field of `stats` that holds a structure whose `type`
/// field is `RemoteInboundRtp`, or `None` when no field qualifies.
///
/// Fields that are not structures, or whose `type` cannot be read as a
/// `WebRTCStatsType`, are skipped.
fn lookup_remote_inbound_rtp_stats(stats: &gst::StructureRef) -> Option<gst::Structure> {
    stats
        .iter()
        .filter_map(|(_, value)| value.get::<gst::Structure>().ok())
        .find(|entry| {
            entry
                .get::<gst_webrtc::WebRTCStatsType>("type")
                .map_or(false, |kind| {
                    kind == gst_webrtc::WebRTCStatsType::RemoteInboundRtp
                })
        })
}
fn lookup_transport_stats(stats: &gst::StructureRef) -> Option<gst::Structure> {
for (_, field_value) in stats {
if let Ok(s) = field_value.get::<gst::Structure>() {
@ -627,7 +635,8 @@ impl VideoEncoder {
impl CongestionController {
fn new(peer_id: &str, min_bitrate: u32, max_bitrate: u32) -> Self {
Self {
target_bitrate: 0,
target_bitrate_on_delay: 0,
target_bitrate_on_loss: 0,
bitrate_ema: None,
bitrate_emvar: 0.,
last_update_time: None,
@ -637,19 +646,21 @@ impl CongestionController {
}
}
fn update(
fn update_delay(
&mut self,
element: &super::WebRTCSink,
twcc_stats: &gst::StructureRef,
rtt: f64,
) -> CongestionControlOp {
let target_bitrate = self.target_bitrate as f64;
let target_bitrate = f64::min(
self.target_bitrate_on_delay as f64,
self.target_bitrate_on_loss as f64,
);
// Unwrap, all those fields must be there or there's been an API
// break, which qualifies as programming error
let bitrate_sent = twcc_stats.get::<u32>("bitrate-sent").unwrap();
let bitrate_recv = twcc_stats.get::<u32>("bitrate-recv").unwrap();
let delta_of_delta = twcc_stats.get::<i64>("avg-delta-of-delta").unwrap();
let loss_percentage = twcc_stats.get::<f64>("packet-loss-pct").unwrap();
let sent_minus_received = bitrate_sent.saturating_sub(bitrate_recv);
@ -665,61 +676,22 @@ impl CongestionController {
);
if delay_factor > 0.1 {
CongestionControlOp::Decrease(if delay_factor < 0.64 {
gst_trace!(
CAT,
obj: element,
"consumer {}: low delay factor {}",
self.peer_id,
delay_factor,
);
0.96
let (factor, reason) = if delay_factor < 0.64 {
(0.96, format!("low delay factor {}", delay_factor))
} else {
gst_trace!(
CAT,
obj: element,
"consumer {}: High delay factor",
self.peer_id
);
delay_factor.sqrt().sqrt().clamp(0.8, 0.96)
})
} else if delta_of_delta > 1000000 {
CongestionControlOp::Decrease(if loss_percentage < 10. {
gst_trace!(
CAT,
obj: element,
"consumer {}: moderate loss high delta",
self.peer_id
);
0.97
} else {
gst_log!(
CAT,
obj: element,
"consumer: {}: high loss high delta",
self.peer_id
);
((100. - loss_percentage) / 100.).clamp(0.7, 0.98)
})
} else if loss_percentage > 10. {
CongestionControlOp::Decrease(
((100. - (0.5 * loss_percentage)) / 100.).clamp(0.7, 0.98),
)
} else if loss_percentage > 2. {
gst_trace!(
CAT,
obj: element,
"consumer {}: moderate loss",
self.peer_id
);
CongestionControlOp::Hold
(
delay_factor.sqrt().sqrt().clamp(0.8, 0.96),
format!("High delay factor {}", delay_factor),
)
};
CongestionControlOp::Decrease { factor, reason }
} else if delta_of_delta > 1_000_000 {
CongestionControlOp::Decrease {
factor: 0.97,
reason: format!("High delta: {}", delta_of_delta),
}
} else {
gst_trace!(
CAT,
obj: element,
"consumer {}: no detected congestion",
self.peer_id
);
CongestionControlOp::Increase(if let Some(ema) = self.bitrate_ema {
let bitrate_stdev = self.bitrate_emvar.sqrt();
@ -800,81 +772,187 @@ impl CongestionController {
}
}
fn clamp_bitrate(&mut self, bitrate: i32, n_encoders: i32) {
self.target_bitrate = bitrate.clamp(
self.min_bitrate as i32 * n_encoders,
self.max_bitrate as i32 * n_encoders,
);
/// Clamps `bitrate` to the allowed range and stores it as the target of the
/// controller selected by `controller_type`.
///
/// The range is `min_bitrate * n_encoders ..= max_bitrate * n_encoders`.
/// `ControllerType::Loss` writes `target_bitrate_on_loss`,
/// `ControllerType::Delay` writes `target_bitrate_on_delay`.
fn clamp_bitrate(&mut self, bitrate: i32, n_encoders: i32, controller_type: ControllerType) {
    // Compute the bounds before taking the mutable borrow of the target field.
    let lower = self.min_bitrate as i32 * n_encoders;
    let upper = self.max_bitrate as i32 * n_encoders;

    let target = match controller_type {
        ControllerType::Loss => &mut self.target_bitrate_on_loss,
        ControllerType::Delay => &mut self.target_bitrate_on_delay,
    };
    *target = bitrate.clamp(lower, upper);
}
fn control(
/// Collects every field of `stats` that holds a structure whose `type`
/// field is `RemoteInboundRtp`.
///
/// Fields that are not structures, or whose `type` cannot be read as a
/// `WebRTCStatsType`, are skipped. The result is empty when nothing matches.
fn get_remote_inbound_stats(&self, stats: &gst::StructureRef) -> Vec<gst::Structure> {
    stats
        .iter()
        .filter_map(|(_, value)| value.get::<gst::Structure>().ok())
        .filter(|entry| {
            entry
                .get::<gst_webrtc::WebRTCStatsType>("type")
                .map_or(false, |kind| {
                    kind == gst_webrtc::WebRTCStatsType::RemoteInboundRtp
                })
        })
        .collect()
}
/// Averages the `round-trip-time` field over all remote-inbound RTP stats
/// found in `stats`.
///
/// Entries whose `round-trip-time` cannot be read as an `f64` are logged at
/// debug level and left out of the average. Returns `0.` when no entry
/// provides a value.
fn lookup_rtt(&self, stats: &gst::StructureRef) -> f64 {
    let mut total_rtt = 0.;
    let mut samples = 0u64;

    for entry in &self.get_remote_inbound_stats(stats) {
        match entry.get::<f64>("round-trip-time") {
            Ok(value) => {
                total_rtt += value;
                samples += 1;
            }
            Err(err) => {
                gst_debug!(CAT, "{:?}", err);
            }
        }
    }

    // Divide by at least 1 so that an empty sample set yields 0 instead of NaN.
    let rtt = total_rtt / f64::max(1., samples as f64);
    gst_log!(CAT, "Round trip time: {}", rtt);

    rtt
}
fn loss_control(
&mut self,
element: &super::WebRTCSink,
stats: &gst::StructureRef,
encoders: &mut Vec<VideoEncoder>,
) {
let n_encoders = encoders.len() as i32;
let loss_percentage = stats.get::<f64>("packet-loss-pct").unwrap();
let rtt = lookup_remote_inbound_rtp_stats(stats)
.and_then(|s| s.get::<f64>("round-trip-time").ok())
.unwrap_or(0.);
self.apply_control_op(
element,
encoders,
if loss_percentage > 10. {
CongestionControlOp::Decrease {
factor: ((100. - (0.5 * loss_percentage)) / 100.).clamp(0.7, 0.98),
reason: format!("High loss: {}", loss_percentage),
}
} else if loss_percentage > 2. {
CongestionControlOp::Hold
} else {
CongestionControlOp::Increase(IncreaseType::Multiplicative(1.05))
},
ControllerType::Loss,
);
}
fn delay_control(
&mut self,
element: &super::WebRTCSink,
stats: &gst::StructureRef,
encoders: &mut Vec<VideoEncoder>,
) {
if let Some(twcc_stats) = lookup_transport_stats(stats).and_then(|transport_stats| {
transport_stats.get::<gst::Structure>("gst-twcc-stats").ok()
}) {
let control_op = self.update(element, &twcc_stats, rtt);
let op = self.update_delay(element, &twcc_stats, self.lookup_rtt(stats));
self.apply_control_op(element, encoders, op, ControllerType::Delay);
}
}
gst_trace!(
CAT,
obj: element,
"consumer {}: applying congestion control operation {:?}",
self.peer_id,
control_op
);
fn apply_control_op(
&mut self,
element: &super::WebRTCSink,
encoders: &mut Vec<VideoEncoder>,
control_op: CongestionControlOp,
controller_type: ControllerType,
) {
gst_trace!(
CAT,
obj: element,
"consumer {}: applying congestion control operation {:?}",
self.peer_id,
control_op
);
match control_op {
CongestionControlOp::Hold => (),
CongestionControlOp::Increase(IncreaseType::Additive(value)) => {
self.clamp_bitrate(self.target_bitrate + value as i32, n_encoders);
}
CongestionControlOp::Increase(IncreaseType::Multiplicative(factor)) => {
self.clamp_bitrate((self.target_bitrate as f64 * factor) as i32, n_encoders);
}
CongestionControlOp::Decrease(factor) => {
self.clamp_bitrate((self.target_bitrate as f64 * factor) as i32, n_encoders);
let n_encoders = encoders.len() as i32;
let prev_bitrate = i32::min(self.target_bitrate_on_delay, self.target_bitrate_on_loss);
match &control_op {
CongestionControlOp::Hold => {}
CongestionControlOp::Increase(IncreaseType::Additive(value)) => {
self.clamp_bitrate(
self.target_bitrate_on_delay + *value as i32,
n_encoders,
controller_type,
);
}
CongestionControlOp::Increase(IncreaseType::Multiplicative(factor)) => {
self.clamp_bitrate(
(self.target_bitrate_on_delay as f64 * factor) as i32,
n_encoders,
controller_type,
);
}
CongestionControlOp::Decrease { factor, .. } => {
self.clamp_bitrate(
(self.target_bitrate_on_delay as f64 * factor) as i32,
n_encoders,
controller_type,
);
if let ControllerType::Delay = controller_type {
// Smoothing factor
let alpha = 0.75;
if let Some(ema) = self.bitrate_ema {
let sigma: f64 = (self.target_bitrate as f64) - ema;
let sigma: f64 = (self.target_bitrate_on_delay as f64) - ema;
self.bitrate_ema = Some(ema + (alpha * sigma));
self.bitrate_emvar =
(1. - alpha) * (self.bitrate_emvar + alpha * sigma.powi(2));
} else {
self.bitrate_ema = Some(self.target_bitrate as f64);
self.bitrate_ema = Some(self.target_bitrate_on_delay as f64);
self.bitrate_emvar = 0.;
}
}
}
}
let target_bitrate = self.target_bitrate / n_encoders;
let target_bitrate =
i32::min(self.target_bitrate_on_delay, self.target_bitrate_on_loss).clamp(
self.min_bitrate as i32 * n_encoders,
self.max_bitrate as i32 * n_encoders,
) / n_encoders;
let fec_ratio = {
if target_bitrate <= 2000000 || self.max_bitrate <= 2000000 {
0f64
} else {
(target_bitrate as f64 - 2000000f64) / (self.max_bitrate as f64 - 2000000f64)
}
};
if target_bitrate != prev_bitrate {
gst_info!(
CAT,
"{:?} {} => {}",
control_op,
human_bytes::human_bytes(prev_bitrate),
human_bytes::human_bytes(target_bitrate)
);
}
let fec_percentage = (fec_ratio * 50f64) as u32;
for encoder in encoders.iter_mut() {
encoder.set_bitrate(element, target_bitrate);
encoder
.transceiver
.set_property("fec-percentage", fec_percentage);
let fec_ratio = {
if target_bitrate <= 2000000 || self.max_bitrate <= 2000000 {
0f64
} else {
(target_bitrate as f64 - 2000000f64) / (self.max_bitrate as f64 - 2000000f64)
}
};
let fec_percentage = (fec_ratio * 50f64) as u32;
for encoder in encoders.iter_mut() {
encoder.set_bitrate(element, target_bitrate);
encoder
.transceiver
.set_property("fec-percentage", fec_percentage);
}
}
}
@ -953,6 +1031,27 @@ impl State {
}
impl Consumer {
/// Builds a consumer for `peer_id` around an existing pipeline and webrtcbin.
///
/// The caller supplies the pipeline, the webrtcbin element, the peer id, an
/// optional congestion controller and the maximum bitrate. Everything else
/// starts out empty: no SDP, no pads, no encoders, an empty
/// `application/x-webrtc-stats` structure and no stats signal handler.
fn new(
    pipeline: gst::Pipeline,
    webrtcbin: gst::Element,
    peer_id: String,
    congestion_controller: Option<CongestionController>,
    max_bitrate: u32,
) -> Self {
    let stats = gst::Structure::new_empty("application/x-webrtc-stats");

    Self {
        pipeline,
        webrtcbin,
        webrtc_pads: HashMap::new(),
        peer_id,
        congestion_controller,
        encoders: Vec::new(),
        sdp: None,
        stats,
        max_bitrate,
        stats_sigid: None,
    }
}
fn gather_stats(&self) -> gst::Structure {
let mut ret = self.stats.to_owned();
@ -1140,7 +1239,9 @@ impl Consumer {
);
if let Some(congestion_controller) = self.congestion_controller.as_mut() {
congestion_controller.target_bitrate += enc.bitrate();
congestion_controller.target_bitrate_on_delay += enc.bitrate();
congestion_controller.target_bitrate_on_loss =
congestion_controller.target_bitrate_on_delay;
enc.transceiver.set_property("fec-percentage", 0u32);
} else {
/* If congestion control is disabled, we simply use the highest
@ -1658,12 +1759,11 @@ impl WebRTCSink {
}
});
let mut consumer = Consumer {
pipeline: pipeline.clone(),
webrtcbin: webrtcbin.clone(),
webrtc_pads: HashMap::new(),
peer_id: peer_id.to_string(),
congestion_controller: match settings.cc_heuristic {
let mut consumer = Consumer::new(
pipeline.clone(),
webrtcbin.clone(),
peer_id.to_string(),
match settings.cc_heuristic {
WebRTCSinkCongestionControl::Disabled => None,
WebRTCSinkCongestionControl::Homegrown => Some(CongestionController::new(
peer_id,
@ -1671,11 +1771,39 @@ impl WebRTCSink {
settings.max_bitrate,
)),
},
encoders: Vec::new(),
sdp: None,
stats: gst::Structure::new_empty("application/x-webrtc-stats"),
max_bitrate: settings.max_bitrate,
};
settings.max_bitrate,
);
let rtpbin = webrtcbin
.dynamic_cast_ref::<gst::ChildProxy>()
.unwrap()
.child_by_name("rtpbin")
.unwrap();
if consumer.congestion_controller.is_some() {
let peer_id_str = peer_id.to_string();
if consumer.stats_sigid.is_none() {
consumer.stats_sigid = Some(rtpbin.connect_closure("on-new-ssrc", true,
glib::closure!(@weak-allow-none element, @weak-allow-none webrtcbin
=> move |rtpbin: gst::Object, session_id: u32, _src: u32| {
let session = rtpbin.emit_by_name::<gst::Element>("get-session", &[&session_id]);
let element = element.expect("on-new-ssrc emited when webrtcsink has been disposed?");
let webrtcbin = webrtcbin.unwrap();
let mut state = element.imp().state.lock().unwrap();
if let Some(mut consumer) = state.consumers.get_mut(&peer_id_str) {
consumer.stats_sigid = Some(session.connect_notify(Some("twcc-stats"),
glib::clone!(@strong peer_id_str, @weak webrtcbin, @weak element => @default-panic, move |sess, pspec| {
// Run the loss-based control algorithm on new peer TWCC feedback
element.imp().process_loss_stats(&element, &peer_id_str, &sess.property::<gst::Structure>(pspec.name()));
})
));
}
})
));
}
}
state
.streams
@ -1805,22 +1933,41 @@ impl WebRTCSink {
Ok(())
}
fn process_webrtcbin_stats(
fn process_loss_stats(
&self,
element: &super::WebRTCSink,
peer_id: &str,
stats: &gst::StructureRef,
stats: &gst::Structure,
) {
let mut state = self.state.lock().unwrap();
if let Some(consumer) = state.consumers.get_mut(peer_id) {
let mut state = element.imp().state.lock().unwrap();
if let Some(mut consumer) = state.consumers.get_mut(peer_id) {
if let Some(congestion_controller) = consumer.congestion_controller.as_mut() {
congestion_controller.control(element, stats, &mut consumer.encoders);
congestion_controller.loss_control(&element, stats, &mut consumer.encoders);
}
consumer.stats = stats.to_owned();
}
}
/// Requests stats from `webrtcbin` and runs the delay-based controller on
/// the reply for the consumer identified by `peer_id`.
///
/// The request is made by emitting the `get-stats` action signal with a
/// promise; the work below happens in the promise's change function.
fn process_stats(&self, element: &super::WebRTCSink, webrtcbin: gst::Element, peer_id: &str) {
// Owned copy so the closure can keep the peer id.
let peer_id = peer_id.to_string();
let promise = gst::Promise::with_change_func(
// `element` is captured weakly, `peer_id` strongly.
glib::clone!(@strong peer_id, @weak element => move |reply| {
// Error replies and replies without stats are ignored.
if let Ok(Some(stats)) = reply {
let mut state = element.imp().state.lock().unwrap();
if let Some(mut consumer) = state.consumers.get_mut(&peer_id) {
// Only consumers with a congestion controller run delay control.
if let Some(congestion_controller) = consumer.congestion_controller.as_mut() {
congestion_controller.delay_control(&element, stats, &mut consumer.encoders,);
}
// The consumer keeps a copy of the latest stats either way.
consumer.stats = stats.to_owned();
}
}
}),
);
// `None` is passed as the pad argument of `get-stats`.
webrtcbin.emit_by_name::<()>("get-stats", &[&None::<gst::Pad>, &promise]);
}
fn on_remote_description_set(&self, element: &super::WebRTCSink, peer_id: String) {
let mut state = self.state.lock().unwrap();
let mut remove = false;
@ -1874,18 +2021,12 @@ impl WebRTCSink {
while interval.next().await.is_some() {
let element_clone = element_clone.clone();
let peer_id_clone = peer_id_clone.clone();
if let Some(webrtcbin) = webrtcbin.upgrade() {
let promise = gst::Promise::with_change_func(move |reply| {
if let Some(element) = element_clone.upgrade() {
let this = Self::from_instance(&element);
if let Ok(Some(stats)) = reply {
this.process_webrtcbin_stats(&element, &peer_id_clone, stats);
}
}
});
webrtcbin.emit_by_name::<()>("get-stats", &[&None::<gst::Pad>, &promise]);
if let (Some(webrtcbin), Some(element)) =
(webrtcbin.upgrade(), element_clone.upgrade())
{
element
.imp()
.process_stats(&element, webrtcbin, &peer_id_clone);
} else {
break;
}
@ -2426,6 +2567,7 @@ impl ObjectImpl for WebRTCSink {
.set_bitrate(&self.instance(), consumer.max_bitrate as i32);
encoder.transceiver.set_property("fec-percentage", 50u32);
}
consumer.stats_sigid.take();
}
WebRTCSinkCongestionControl::Homegrown => {
let _ = consumer.congestion_controller.insert(