From 075a625305a625ce5be7024f54525dc29f319556 Mon Sep 17 00:00:00 2001 From: Thibault Saunier Date: Wed, 30 Mar 2022 15:39:19 +0000 Subject: [PATCH] Make our "loss based control" algorithm closer to what is defined in [GCC] As specified in Google Congestion Control we should run the packet loss estimation algorithm "every time feedback from the receiver is received". And, also as defined by GCC, we now have 2 different estimated bitrates, one for the delay-based controller value and one for the loss-based one, and we use the minimum value between those 2 as our current estimation. [GCC]: https://datatracker.ietf.org/doc/html/draft-ietf-rmcat-gcc-02 --- Cargo.lock | 7 + plugins/Cargo.toml | 1 + plugins/src/webrtcsink/imp.rs | 442 ++++++++++++++++++++++------------ 3 files changed, 300 insertions(+), 150 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index bc59bae8..6b106632 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -903,6 +903,12 @@ version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9100414882e15fb7feccb4897e5f0ff0ff1ca7d1a86a23208ada4d7a18e6c6c4" +[[package]] +name = "human_bytes" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88a0d4dc39ec942e44c1c306aa196da67f2bd6a30dc7b4a475465c13ccf28817" + [[package]] name = "idna" version = "0.2.3" @@ -1847,6 +1853,7 @@ dependencies = [ "gstreamer-sdp", "gstreamer-video", "gstreamer-webrtc", + "human_bytes", "once_cell", "serde", "serde_json", diff --git a/plugins/Cargo.toml b/plugins/Cargo.toml index b74cfa88..1038bc94 100644 --- a/plugins/Cargo.toml +++ b/plugins/Cargo.toml @@ -27,6 +27,7 @@ serde = "1" serde_json = "1" fastrand = "1.0" webrtcsink-protocol = { version = "0.1", path="../protocol" } +human_bytes = "0.3.1" [dev-dependencies] tracing = { version = "0.1", features = ["log"] } diff --git a/plugins/src/webrtcsink/imp.rs b/plugins/src/webrtcsink/imp.rs index 10b377c0..1eca95dd 100644 --- 
a/plugins/src/webrtcsink/imp.rs +++ b/plugins/src/webrtcsink/imp.rs @@ -126,14 +126,21 @@ struct VideoEncoder { } struct CongestionController { - /// Overall bitrate target for all video streams. + /// Note: The target bitrate applied is the min of + /// target_bitrate_on_delay and target_bitrate_on_loss + /// + /// Bitrate target based on delay factor for all video streams. /// Hasn't been tested with multiple video streams, but /// current design is simply to divide bitrate equally. - bitrate_ema: Option, + target_bitrate_on_delay: i32, + + /// Bitrate target based on loss for all video streams. + target_bitrate_on_loss: i32, + /// Exponential moving average, updated when bitrate is /// decreased, discarded when increased again past last /// congestion window. Smoothing factor hardcoded. - target_bitrate: i32, + bitrate_ema: Option, /// Exponentially weighted moving variance, recursively /// updated along with bitrate_ema. sqrt'd to obtain standard /// deviation, used to determine whether to increase bitrate @@ -162,11 +169,23 @@ enum CongestionControlOp { /// Don't update target bitrate Hold, /// Decrease target bitrate - Decrease(f64), + Decrease { + factor: f64, + #[allow(dead_code)] + reason: String, // for Debug + }, /// Increase target bitrate, either additively or multiplicatively Increase(IncreaseType), } +#[derive(Debug, Clone, Copy)] +enum ControllerType { + // Running the "delay-based controller" + Delay, + // Running the "loss based controller" + Loss, +} + struct Consumer { pipeline: gst::Pipeline, webrtcbin: gst::Element, @@ -179,6 +198,8 @@ struct Consumer { stats: gst::Structure, max_bitrate: u32, + + stats_sigid: Option, } #[derive(PartialEq)] @@ -215,6 +236,7 @@ fn create_navigation_event>(sink: &N, msg: &str) { gst_error!(CAT, "Invalid navigation event: {:?}", msg); } } + /// Simple utility for tearing down a pipeline cleanly struct PipelineWrapper(gst::Pipeline); @@ -465,20 +487,6 @@ fn setup_encoding( Ok((enc, conv_filter, pay)) } -fn 
lookup_remote_inbound_rtp_stats(stats: &gst::StructureRef) -> Option { - for (_, field_value) in stats { - if let Ok(s) = field_value.get::() { - if let Ok(type_) = s.get::("type") { - if type_ == gst_webrtc::WebRTCStatsType::RemoteInboundRtp { - return Some(s); - } - } - } - } - - None -} - fn lookup_transport_stats(stats: &gst::StructureRef) -> Option { for (_, field_value) in stats { if let Ok(s) = field_value.get::() { @@ -627,7 +635,8 @@ impl VideoEncoder { impl CongestionController { fn new(peer_id: &str, min_bitrate: u32, max_bitrate: u32) -> Self { Self { - target_bitrate: 0, + target_bitrate_on_delay: 0, + target_bitrate_on_loss: 0, bitrate_ema: None, bitrate_emvar: 0., last_update_time: None, @@ -637,19 +646,21 @@ impl CongestionController { } } - fn update( + fn update_delay( &mut self, element: &super::WebRTCSink, twcc_stats: &gst::StructureRef, rtt: f64, ) -> CongestionControlOp { - let target_bitrate = self.target_bitrate as f64; + let target_bitrate = f64::min( + self.target_bitrate_on_delay as f64, + self.target_bitrate_on_loss as f64, + ); // Unwrap, all those fields must be there or there's been an API // break, which qualifies as programming error let bitrate_sent = twcc_stats.get::("bitrate-sent").unwrap(); let bitrate_recv = twcc_stats.get::("bitrate-recv").unwrap(); let delta_of_delta = twcc_stats.get::("avg-delta-of-delta").unwrap(); - let loss_percentage = twcc_stats.get::("packet-loss-pct").unwrap(); let sent_minus_received = bitrate_sent.saturating_sub(bitrate_recv); @@ -665,61 +676,22 @@ impl CongestionController { ); if delay_factor > 0.1 { - CongestionControlOp::Decrease(if delay_factor < 0.64 { - gst_trace!( - CAT, - obj: element, - "consumer {}: low delay factor {}", - self.peer_id, - delay_factor, - ); - 0.96 + let (factor, reason) = if delay_factor < 0.64 { + (0.96, format!("low delay factor {}", delay_factor)) } else { - gst_trace!( - CAT, - obj: element, - "consumer {}: High delay factor", - self.peer_id - ); - 
delay_factor.sqrt().sqrt().clamp(0.8, 0.96) - }) - } else if delta_of_delta > 1000000 { - CongestionControlOp::Decrease(if loss_percentage < 10. { - gst_trace!( - CAT, - obj: element, - "consumer {}: moderate loss high delta", - self.peer_id - ); - 0.97 - } else { - gst_log!( - CAT, - obj: element, - "consumer: {}: high loss high delta", - self.peer_id - ); - ((100. - loss_percentage) / 100.).clamp(0.7, 0.98) - }) - } else if loss_percentage > 10. { - CongestionControlOp::Decrease( - ((100. - (0.5 * loss_percentage)) / 100.).clamp(0.7, 0.98), - ) - } else if loss_percentage > 2. { - gst_trace!( - CAT, - obj: element, - "consumer {}: moderate loss", - self.peer_id - ); - CongestionControlOp::Hold + ( + delay_factor.sqrt().sqrt().clamp(0.8, 0.96), + format!("High delay factor {}", delay_factor), + ) + }; + + CongestionControlOp::Decrease { factor, reason } + } else if delta_of_delta > 1_000_000 { + CongestionControlOp::Decrease { + factor: 0.97, + reason: format!("High delta: {}", delta_of_delta), + } } else { - gst_trace!( - CAT, - obj: element, - "consumer {}: no detected congestion", - self.peer_id - ); CongestionControlOp::Increase(if let Some(ema) = self.bitrate_ema { let bitrate_stdev = self.bitrate_emvar.sqrt(); @@ -800,81 +772,187 @@ impl CongestionController { } } - fn clamp_bitrate(&mut self, bitrate: i32, n_encoders: i32) { - self.target_bitrate = bitrate.clamp( - self.min_bitrate as i32 * n_encoders, - self.max_bitrate as i32 * n_encoders, - ); + fn clamp_bitrate(&mut self, bitrate: i32, n_encoders: i32, controller_type: ControllerType) { + match controller_type { + ControllerType::Loss => { + self.target_bitrate_on_loss = bitrate.clamp( + self.min_bitrate as i32 * n_encoders, + self.max_bitrate as i32 * n_encoders, + ) + } + + ControllerType::Delay => { + self.target_bitrate_on_delay = bitrate.clamp( + self.min_bitrate as i32 * n_encoders, + self.max_bitrate as i32 * n_encoders, + ) + } + } } - fn control( + fn get_remote_inbound_stats(&self, stats: 
&gst::StructureRef) -> Vec { + let mut inbound_rtp_stats: Vec = Default::default(); + for (_, field_value) in stats { + if let Ok(s) = field_value.get::() { + if let Ok(type_) = s.get::("type") { + if type_ == gst_webrtc::WebRTCStatsType::RemoteInboundRtp { + inbound_rtp_stats.push(s); + } + } + } + } + + inbound_rtp_stats + } + + fn lookup_rtt(&self, stats: &gst::StructureRef) -> f64 { + let inbound_rtp_stats = self.get_remote_inbound_stats(stats); + let mut rtt = 0.; + let mut n_rtts = 0u64; + for inbound_stat in &inbound_rtp_stats { + if let Err(err) = (|| -> Result<(), gst::structure::GetError> { + rtt += inbound_stat.get::("round-trip-time")?; + n_rtts += 1; + + Ok(()) + })() { + gst_debug!(CAT, "{:?}", err); + } + } + + rtt /= f64::max(1., n_rtts as f64); + + gst_log!(CAT, "Round trip time: {}", rtt); + + rtt + } + + fn loss_control( &mut self, element: &super::WebRTCSink, stats: &gst::StructureRef, encoders: &mut Vec, ) { - let n_encoders = encoders.len() as i32; + let loss_percentage = stats.get::("packet-loss-pct").unwrap(); - let rtt = lookup_remote_inbound_rtp_stats(stats) - .and_then(|s| s.get::("round-trip-time").ok()) - .unwrap_or(0.); + self.apply_control_op( + element, + encoders, + if loss_percentage > 10. { + CongestionControlOp::Decrease { + factor: ((100. - (0.5 * loss_percentage)) / 100.).clamp(0.7, 0.98), + reason: format!("High loss: {}", loss_percentage), + } + } else if loss_percentage > 2. 
{ + CongestionControlOp::Hold + } else { + CongestionControlOp::Increase(IncreaseType::Multiplicative(1.05)) + }, + ControllerType::Loss, + ); + } + fn delay_control( + &mut self, + element: &super::WebRTCSink, + stats: &gst::StructureRef, + encoders: &mut Vec, + ) { if let Some(twcc_stats) = lookup_transport_stats(stats).and_then(|transport_stats| { transport_stats.get::("gst-twcc-stats").ok() }) { - let control_op = self.update(element, &twcc_stats, rtt); + let op = self.update_delay(element, &twcc_stats, self.lookup_rtt(stats)); + self.apply_control_op(element, encoders, op, ControllerType::Delay); + } + } - gst_trace!( - CAT, - obj: element, - "consumer {}: applying congestion control operation {:?}", - self.peer_id, - control_op - ); + fn apply_control_op( + &mut self, + element: &super::WebRTCSink, + encoders: &mut Vec, + control_op: CongestionControlOp, + controller_type: ControllerType, + ) { + gst_trace!( + CAT, + obj: element, + "consumer {}: applying congestion control operation {:?}", + self.peer_id, + control_op + ); - match control_op { - CongestionControlOp::Hold => (), - CongestionControlOp::Increase(IncreaseType::Additive(value)) => { - self.clamp_bitrate(self.target_bitrate + value as i32, n_encoders); - } - CongestionControlOp::Increase(IncreaseType::Multiplicative(factor)) => { - self.clamp_bitrate((self.target_bitrate as f64 * factor) as i32, n_encoders); - } - CongestionControlOp::Decrease(factor) => { - self.clamp_bitrate((self.target_bitrate as f64 * factor) as i32, n_encoders); + let n_encoders = encoders.len() as i32; + let prev_bitrate = i32::min(self.target_bitrate_on_delay, self.target_bitrate_on_loss); + match &control_op { + CongestionControlOp::Hold => {} + CongestionControlOp::Increase(IncreaseType::Additive(value)) => { + self.clamp_bitrate( + self.target_bitrate_on_delay + *value as i32, + n_encoders, + controller_type, + ); + } + CongestionControlOp::Increase(IncreaseType::Multiplicative(factor)) => { + self.clamp_bitrate( + 
(self.target_bitrate_on_delay as f64 * factor) as i32, + n_encoders, + controller_type, + ); + } + CongestionControlOp::Decrease { factor, .. } => { + self.clamp_bitrate( + (self.target_bitrate_on_delay as f64 * factor) as i32, + n_encoders, + controller_type, + ); + if let ControllerType::Delay = controller_type { // Smoothing factor let alpha = 0.75; if let Some(ema) = self.bitrate_ema { - let sigma: f64 = (self.target_bitrate as f64) - ema; + let sigma: f64 = (self.target_bitrate_on_delay as f64) - ema; self.bitrate_ema = Some(ema + (alpha * sigma)); self.bitrate_emvar = (1. - alpha) * (self.bitrate_emvar + alpha * sigma.powi(2)); } else { - self.bitrate_ema = Some(self.target_bitrate as f64); + self.bitrate_ema = Some(self.target_bitrate_on_delay as f64); self.bitrate_emvar = 0.; } } } + } - let target_bitrate = self.target_bitrate / n_encoders; + let target_bitrate = + i32::min(self.target_bitrate_on_delay, self.target_bitrate_on_loss).clamp( + self.min_bitrate as i32 * n_encoders, + self.max_bitrate as i32 * n_encoders, + ) / n_encoders; - let fec_ratio = { - if target_bitrate <= 2000000 || self.max_bitrate <= 2000000 { - 0f64 - } else { - (target_bitrate as f64 - 2000000f64) / (self.max_bitrate as f64 - 2000000f64) - } - }; + if target_bitrate != prev_bitrate { + gst_info!( + CAT, + "{:?} {} => {}", + control_op, + human_bytes::human_bytes(prev_bitrate), + human_bytes::human_bytes(target_bitrate) + ); + } - let fec_percentage = (fec_ratio * 50f64) as u32; - - for encoder in encoders.iter_mut() { - encoder.set_bitrate(element, target_bitrate); - encoder - .transceiver - .set_property("fec-percentage", fec_percentage); + let fec_ratio = { + if target_bitrate <= 2000000 || self.max_bitrate <= 2000000 { + 0f64 + } else { + (target_bitrate as f64 - 2000000f64) / (self.max_bitrate as f64 - 2000000f64) } + }; + + let fec_percentage = (fec_ratio * 50f64) as u32; + + for encoder in encoders.iter_mut() { + encoder.set_bitrate(element, target_bitrate); + encoder + 
.transceiver + .set_property("fec-percentage", fec_percentage); } } } @@ -953,6 +1031,27 @@ impl State { } impl Consumer { + fn new( + pipeline: gst::Pipeline, + webrtcbin: gst::Element, + peer_id: String, + congestion_controller: Option, + max_bitrate: u32, + ) -> Self { + Self { + pipeline, + webrtcbin, + peer_id, + congestion_controller, + max_bitrate, + sdp: None, + stats: gst::Structure::new_empty("application/x-webrtc-stats"), + webrtc_pads: HashMap::new(), + encoders: Vec::new(), + stats_sigid: None, + } + } + fn gather_stats(&self) -> gst::Structure { let mut ret = self.stats.to_owned(); @@ -1140,7 +1239,9 @@ impl Consumer { ); if let Some(congestion_controller) = self.congestion_controller.as_mut() { - congestion_controller.target_bitrate += enc.bitrate(); + congestion_controller.target_bitrate_on_delay += enc.bitrate(); + congestion_controller.target_bitrate_on_loss = + congestion_controller.target_bitrate_on_delay; enc.transceiver.set_property("fec-percentage", 0u32); } else { /* If congestion control is disabled, we simply use the highest @@ -1658,12 +1759,11 @@ impl WebRTCSink { } }); - let mut consumer = Consumer { - pipeline: pipeline.clone(), - webrtcbin: webrtcbin.clone(), - webrtc_pads: HashMap::new(), - peer_id: peer_id.to_string(), - congestion_controller: match settings.cc_heuristic { + let mut consumer = Consumer::new( + pipeline.clone(), + webrtcbin.clone(), + peer_id.to_string(), + match settings.cc_heuristic { WebRTCSinkCongestionControl::Disabled => None, WebRTCSinkCongestionControl::Homegrown => Some(CongestionController::new( peer_id, @@ -1671,11 +1771,39 @@ impl WebRTCSink { settings.max_bitrate, )), }, - encoders: Vec::new(), - sdp: None, - stats: gst::Structure::new_empty("application/x-webrtc-stats"), - max_bitrate: settings.max_bitrate, - }; + settings.max_bitrate, + ); + + let rtpbin = webrtcbin + .dynamic_cast_ref::() + .unwrap() + .child_by_name("rtpbin") + .unwrap(); + + if consumer.congestion_controller.is_some() { + let 
peer_id_str = peer_id.to_string(); + if consumer.stats_sigid.is_none() { + consumer.stats_sigid = Some(rtpbin.connect_closure("on-new-ssrc", true, + glib::closure!(@weak-allow-none element, @weak-allow-none webrtcbin + => move |rtpbin: gst::Object, session_id: u32, _src: u32| { + let session = rtpbin.emit_by_name::("get-session", &[&session_id]); + + let element = element.expect("on-new-ssrc emited when webrtcsink has been disposed?"); + let webrtcbin = webrtcbin.unwrap(); + let mut state = element.imp().state.lock().unwrap(); + if let Some(mut consumer) = state.consumers.get_mut(&peer_id_str) { + + consumer.stats_sigid = Some(session.connect_notify(Some("twcc-stats"), + glib::clone!(@strong peer_id_str, @weak webrtcbin, @weak element => @default-panic, move |sess, pspec| { + // Run the Loss-based control algorithm on new peer TWCC feedback + element.imp().process_loss_stats(&element, &peer_id_str, &sess.property::(pspec.name())); + }) + )); + } + }) + )); + } + } state .streams @@ -1805,22 +1933,41 @@ impl WebRTCSink { Ok(()) } - fn process_webrtcbin_stats( + fn process_loss_stats( &self, element: &super::WebRTCSink, peer_id: &str, - stats: &gst::StructureRef, + stats: &gst::Structure, ) { - let mut state = self.state.lock().unwrap(); - - if let Some(consumer) = state.consumers.get_mut(peer_id) { + let mut state = element.imp().state.lock().unwrap(); + if let Some(mut consumer) = state.consumers.get_mut(peer_id) { if let Some(congestion_controller) = consumer.congestion_controller.as_mut() { - congestion_controller.control(element, stats, &mut consumer.encoders); + congestion_controller.loss_control(&element, stats, &mut consumer.encoders); } consumer.stats = stats.to_owned(); } } + fn process_stats(&self, element: &super::WebRTCSink, webrtcbin: gst::Element, peer_id: &str) { + let peer_id = peer_id.to_string(); + let promise = gst::Promise::with_change_func( + glib::clone!(@strong peer_id, @weak element => move |reply| { + if let Ok(Some(stats)) = reply { + 
let mut state = element.imp().state.lock().unwrap(); + if let Some(mut consumer) = state.consumers.get_mut(&peer_id) { + if let Some(congestion_controller) = consumer.congestion_controller.as_mut() { + congestion_controller.delay_control(&element, stats, &mut consumer.encoders,); + } + consumer.stats = stats.to_owned(); + } + } + }), + ); + + webrtcbin.emit_by_name::<()>("get-stats", &[&None::, &promise]); + } + fn on_remote_description_set(&self, element: &super::WebRTCSink, peer_id: String) { let mut state = self.state.lock().unwrap(); let mut remove = false; @@ -1874,18 +2021,12 @@ impl WebRTCSink { while interval.next().await.is_some() { let element_clone = element_clone.clone(); let peer_id_clone = peer_id_clone.clone(); - if let Some(webrtcbin) = webrtcbin.upgrade() { - let promise = gst::Promise::with_change_func(move |reply| { - if let Some(element) = element_clone.upgrade() { - let this = Self::from_instance(&element); - - if let Ok(Some(stats)) = reply { - this.process_webrtcbin_stats(&element, &peer_id_clone, stats); - } - } - }); - - webrtcbin.emit_by_name::<()>("get-stats", &[&None::, &promise]); + if let (Some(webrtcbin), Some(element)) = + (webrtcbin.upgrade(), element_clone.upgrade()) + { + element + .imp() + .process_stats(&element, webrtcbin, &peer_id_clone); } else { break; } @@ -2426,6 +2567,7 @@ impl ObjectImpl for WebRTCSink { .set_bitrate(&self.instance(), consumer.max_bitrate as i32); encoder.transceiver.set_property("fec-percentage", 50u32); } + consumer.stats_sigid.take(); } WebRTCSinkCongestionControl::Homegrown => { let _ = consumer.congestion_controller.insert(