diff --git a/net/webrtc/src/webrtcsink/imp.rs b/net/webrtc/src/webrtcsink/imp.rs index c85b267c..82aa2482 100644 --- a/net/webrtc/src/webrtcsink/imp.rs +++ b/net/webrtc/src/webrtcsink/imp.rs @@ -24,6 +24,7 @@ use itertools::Itertools; use once_cell::sync::Lazy; use std::collections::HashMap; +use std::ops::DerefMut; use std::ops::Mul; use std::sync::{mpsc, Arc, Condvar, Mutex}; @@ -295,7 +296,7 @@ pub struct VideoEncoder { stream_name: String, } -struct Session { +struct SessionInner { id: String, pipeline: gst::Pipeline, @@ -325,6 +326,9 @@ struct Session { stats_collection_handle: Option>, } +#[derive(Clone)] +struct Session(Arc>); + #[derive(Debug, PartialEq, Eq, Copy, Clone)] enum SignallerState { Started, @@ -343,156 +347,10 @@ struct SignallerSignals { shutdown: glib::SignalHandlerId, } -struct IceCandidate { - sdp_m_line_index: u32, - candidate: String, -} - -/// Wrapper around `Session`. -/// -/// This makes it possible for the `Session` to be taken out of the `State`, -/// without removing the entry in the `sessions` `HashMap`, thus allowing -/// the `State` lock to be released, e.g. before calling a `Signal`. -/// -/// Taking the `Session`, replaces it with a placeholder which can enqueue -/// items (currently ICE candidates) received while the `Session` is taken. -/// In which case, the enqueued items will be processed when the `Session` is -/// restored. -enum SessionWrapper { - /// The `Session` is available in the `SessionWrapper`. - InPlace(Session), - /// The `Session` was taken out the `SessionWrapper`. - Taken(Vec), -} - -impl SessionWrapper { - /// Unwraps a reference to the `Session` of this `SessionWrapper`. - /// - /// # Panics - /// - /// Panics is the `Session` was taken. - fn unwrap(&self) -> &Session { - match self { - SessionWrapper::InPlace(session) => session, - _ => panic!("Session is not In Place"), - } - } - - /// Unwraps a mutable reference to the `Session` of this `SessionWrapper`. - /// - /// # Panics - /// - /// Panics is the `Session` was taken. - fn unwrap_mut(&mut self) -> &mut Session { - match self { - SessionWrapper::InPlace(session) => session, - _ => panic!("Session is not In Place"), - } - } - - /// Consumes the `SessionWrapper`, returning the wrapped `Session`. - /// - /// # Panics - /// - /// Panics is the `Session` was taken. - fn into_inner(self) -> Session { - match self { - SessionWrapper::InPlace(session) => session, - _ => panic!("Session is not In Place"), - } - } - - /// Takes the `Session` out of this `SessionWrapper`, leaving it in the `Taken` state. - /// - /// # Panics - /// - /// Panics is the `Session` was taken. - fn take(&mut self) -> Session { - use SessionWrapper::*; - match std::mem::replace(self, Taken(Vec::new())) { - InPlace(session) => session, - _ => panic!("Session is not In Place"), - } - } - - /// Restores a `Session` to this `SessionWrapper`. - /// - /// Processes any pending items enqueued while the `Session` was taken. - /// - /// # Panics - /// - /// Panics is the `Session` is already in place. - fn restore(&mut self, element: &super::BaseWebRTCSink, session: Session) { - let SessionWrapper::Taken(ref cands) = self else { - panic!("Session is already in place"); - }; - - if !cands.is_empty() { - gst::trace!( - CAT, - obj = element, - "handling {} pending ice candidates for session {}", - cands.len(), - session.id, - ); - for cand in cands { - session.webrtcbin.emit_by_name::<()>( - "add-ice-candidate", - &[&cand.sdp_m_line_index, &cand.candidate], - ); - } - } - - *self = SessionWrapper::InPlace(session); - } - - /// Adds an ICE candidate to this `SessionWrapper`. - /// - /// If the `Session` is in place, the ICE candidate is added immediately, - /// otherwise, it will be added when the `Session` is restored. - fn add_ice_candidate( - &mut self, - element: &super::BaseWebRTCSink, - session_id: &str, - sdp_m_line_index: u32, - candidate: &str, - ) { - match self { - SessionWrapper::InPlace(session) => { - gst::trace!( - CAT, - obj = element, - "adding ice candidate for session {session_id}" - ); - session - .webrtcbin - .emit_by_name::<()>("add-ice-candidate", &[&sdp_m_line_index, &candidate]); - } - SessionWrapper::Taken(cands) => { - gst::trace!( - CAT, - obj = element, - "queueing ice candidate for session {session_id}" - ); - cands.push(IceCandidate { - sdp_m_line_index, - candidate: candidate.to_string(), - }); - } - } - } -} - -impl From for SessionWrapper { - fn from(session: Session) -> Self { - SessionWrapper::InPlace(session) - } -} - /* Our internal state */ struct State { signaller_state: SignallerState, - sessions: HashMap, + sessions: HashMap, codecs: BTreeMap, /// Used to abort codec discovery codecs_abort_handles: Vec, @@ -1315,7 +1173,7 @@ impl VideoEncoder { } impl State { - fn finalize_session(&mut self, element: &super::BaseWebRTCSink, session: &mut Session) { + fn finalize_session(&mut self, element: &super::BaseWebRTCSink, session: &mut SessionInner) { gst::info!(CAT, "Ending session {}", session.id); session.pipeline.debug_to_dot_file_with_ts( gst::DebugGraphDetails::all(), @@ -1359,8 +1217,7 @@ impl State { session_id: &str, ) -> Option { if let Some(session) = self.sessions.remove(session_id) { - let mut session = session.into_inner(); - self.finalize_session(element, &mut session); + self.finalize_session(element, &mut session.0.lock().unwrap()); Some(session) } else { None @@ -1395,7 +1252,7 @@ impl State { } } -impl Session { +impl SessionInner { fn new( id: String, pipeline: gst::Pipeline, @@ -2366,7 +2223,7 @@ impl BaseWebRTCSink { drop(state); gst::debug!(CAT, imp = self, "Ending sessions"); for session in sessions { - signaller.end_session(&session.id); + signaller.end_session(&session.0.lock().unwrap().id); } gst::debug!(CAT, imp = self, "All sessions have started finalizing"); @@ -2505,6 +2362,8 @@ impl BaseWebRTCSink { if let Some(session) = state.sessions.get(session_id) { session + .0 + .lock() .unwrap() .webrtcbin .emit_by_name::<()>("set-local-description", &[&offer, &None::]); @@ -2528,10 +2387,9 @@ impl BaseWebRTCSink { let settings = self.settings.lock().unwrap(); let signaller = settings.signaller.clone(); drop(settings); - let mut state = self.state.lock().unwrap(); - if let Some(session) = state.sessions.get_mut(session_id) { - let mut session = session.take(); + if let Some(session) = self.state.lock().unwrap().sessions.get(session_id).cloned() { + let mut session = session.0.lock().unwrap(); let sdp = answer.sdp(); session.sdp = Some(sdp.to_owned()); @@ -2543,23 +2401,10 @@ impl BaseWebRTCSink { .and_then(|format| format.parse::().ok()); } - drop(state); - session .webrtcbin .emit_by_name::<()>("set-local-description", &[&answer, &None::]); - let mut state = self.state.lock().unwrap(); - let session_id = session.id.clone(); - - if let Some(session_wrapper) = state.sessions.get_mut(&session_id) { - session_wrapper.restore(&self.obj(), session); - } else { - gst::warning!(CAT, imp = self, "Session {session_id} was removed"); - } - - drop(state); - let maybe_munged_answer = if signaller .has_property("manual-sdp-munging", Some(bool::static_type())) && signaller.property("manual-sdp-munging") @@ -2568,12 +2413,15 @@ impl BaseWebRTCSink { answer } else { // Use the default munging mechanism (signal registered by user) - signaller.munge_sdp(session_id.as_str(), &answer) + signaller.munge_sdp(&session.id, &answer) }; - signaller.send_sdp(&session_id, &maybe_munged_answer); + signaller.send_sdp(&session.id, &maybe_munged_answer); - self.on_remote_description_set(session_id) + let session_id = session.id.clone(); + drop(session); + + self.on_remote_description_set(&session_id) } } @@ -2639,10 +2487,9 @@ impl BaseWebRTCSink { } )); - session - .unwrap() - .webrtcbin - .emit_by_name::<()>("create-answer", &[&None::, &promise]); + let webrtcbin = session.0.lock().unwrap().webrtcbin.clone(); + + webrtcbin.emit_by_name::<()>("create-answer", &[&None::, &promise]); } } @@ -2773,8 +2620,10 @@ impl BaseWebRTCSink { gst::debug!(CAT, imp = self, "Negotiating for session {}", session_id); if let Some(session) = state.sessions.get(session_id) { - let session = session.unwrap(); + let session = session.0.lock().unwrap(); gst::trace!(CAT, imp = self, "WebRTC pads: {:?}", session.webrtc_pads); + let webrtcbin = session.webrtcbin.clone(); + drop(session); if let Some(offer) = offer { let promise = gst::Promise::with_change_func(glib::clone!( @@ -2788,9 +2637,7 @@ impl BaseWebRTCSink { } )); - session - .webrtcbin - .emit_by_name::<()>("set-remote-description", &[&offer, &promise]); + webrtcbin.emit_by_name::<()>("set-remote-description", &[&offer, &promise]); } else { gst::debug!(CAT, imp = self, "Creating offer for session {}", session_id); let promise = gst::Promise::with_change_func(glib::clone!( @@ -2842,9 +2689,7 @@ impl BaseWebRTCSink { } )); - session - .webrtcbin - .emit_by_name::<()>("create-offer", &[&None::, &promise]); + webrtcbin.emit_by_name::<()>("create-offer", &[&None::, &promise]); } } else { gst::debug!( @@ -3196,8 +3041,16 @@ impl BaseWebRTCSink { let state = this.state.lock().unwrap(); if let Some(session) = state.sessions.get(&session_id) { - for webrtc_pad in session.unwrap().webrtc_pads.values() { - if let Some(srcpad) = webrtc_pad.pad.peer() { + let pads: Vec = session + .0 + .lock() + .unwrap() + .webrtc_pads + .values() + .map(|p| p.pad.clone()) + .collect(); + for pad in pads { + if let Some(srcpad) = pad.peer() { srcpad.send_event( gst_video::UpstreamForceKeyUnitEvent::builder() .all_headers(true) @@ -3236,7 +3089,7 @@ impl BaseWebRTCSink { ), ); - let session = Session::new( + let session = SessionInner::new( session_id.clone(), pipeline.clone(), webrtcbin.clone(), @@ -3273,7 +3126,7 @@ impl BaseWebRTCSink { let mut state = element.imp().state.lock().unwrap(); if let Some(session) = state.sessions.get_mut(&session_id_str) { - let session = session.unwrap_mut(); + let mut session = session.0.lock().unwrap(); if session.stats_sigid.is_none() { let session_id_str = session_id_str.clone(); session.stats_sigid = Some(rtp_session.connect_notify( @@ -3383,9 +3236,10 @@ impl BaseWebRTCSink { } }); - state - .sessions - .insert(session_id.to_string(), session.into()); + state.sessions.insert( + session_id.to_string(), + Session(Arc::new(Mutex::new(session))), + ); let mut streams: Vec = state.streams.values().cloned().collect(); @@ -3464,7 +3318,7 @@ impl BaseWebRTCSink { { let mut state = this.state.lock().unwrap(); if let Some(session) = state.sessions.get_mut(&session_id) { - let session = session.unwrap_mut(); + let mut session = session.0.lock().unwrap(); session.webrtc_pads = webrtc_pads; if offer.is_some() { session.codecs = Some(codecs); @@ -3540,6 +3394,7 @@ impl BaseWebRTCSink { if let Some(session) = state.end_session(&self.obj(), session_id) { drop(state); + let session = session.0.lock().unwrap(); signaller .emit_by_name::<()>("consumer-removed", &[&session.peer_id, &session.webrtcbin]); if signal { @@ -3555,7 +3410,9 @@ impl BaseWebRTCSink { fn process_loss_stats(&self, session_id: &str, stats: &gst::Structure) { let mut state = self.state.lock().unwrap(); if let Some(session) = state.sessions.get_mut(session_id) { - let session = session.unwrap_mut(); + /* We need this two-step approach for split-borrowing */ + let mut session_guard = session.0.lock().unwrap(); + let session = session_guard.deref_mut(); if let Some(congestion_controller) = session.congestion_controller.as_mut() { congestion_controller.loss_control(&self.obj(), stats, &mut session.encoders); } @@ -3574,7 +3431,9 @@ impl BaseWebRTCSink { if let Ok(Some(stats)) = reply { let mut state = this.state.lock().unwrap(); if let Some(session) = state.sessions.get_mut(&session_id) { - let session = session.unwrap_mut(); + /* We need this two-step approach for split-borrowing */ + let mut session_guard = session.0.lock().unwrap(); + let session = session_guard.deref_mut(); if let Some(congestion_controller) = session.congestion_controller.as_mut() { congestion_controller.delay_control( @@ -3597,7 +3456,7 @@ impl BaseWebRTCSink { let mut state = self.state.lock().unwrap(); if let Some(session) = state.sessions.get_mut(session_id) { - session.unwrap_mut().rtprtxsend = Some(rtprtxsend); + session.0.lock().unwrap().rtprtxsend = Some(rtprtxsend); } } @@ -3607,7 +3466,7 @@ impl BaseWebRTCSink { let mut state = self.state.lock().unwrap(); if let Some(session) = state.sessions.get_mut(session_id) { - let session = session.unwrap_mut(); + let mut session = session.0.lock().unwrap(); let n_encoders = session.encoders.len(); @@ -3677,106 +3536,106 @@ impl BaseWebRTCSink { } } - fn on_remote_description_set(&self, session_id: String) { - let mut state = self.state.lock().unwrap(); + fn on_remote_description_set(&self, session_id: &str) { + let mut state_guard = self.state.lock().unwrap(); + let mut state = state_guard.deref_mut(); + let Some(session_clone) = state.sessions.get(session_id).map(|s| s.0.clone()) else { + return; + }; let mut remove = false; let codecs = state.codecs.clone(); - if let Some(session) = state.sessions.get_mut(&session_id) { - let mut session = session.take(); + let mut session = session_clone.lock().unwrap(); - for webrtc_pad in session.webrtc_pads.clone().values() { - let transceiver = webrtc_pad - .pad - .property::("transceiver"); + for webrtc_pad in session.webrtc_pads.clone().values() { + let transceiver = webrtc_pad + .pad + .property::("transceiver"); - let Some(ref stream_name) = webrtc_pad.stream_name else { - continue; - }; + let Some(ref stream_name) = webrtc_pad.stream_name else { + continue; + }; - if let Some(mid) = transceiver.mid() { - state - .session_mids - .entry(session_id.clone()) - .or_default() - .insert(mid.to_string(), stream_name.clone()); - state - .session_stream_names - .entry(session_id.clone()) - .or_default() - .insert(stream_name.clone(), mid.to_string()); - } + if let Some(mid) = transceiver.mid() { + state + .session_mids + .entry(session.id.clone()) + .or_default() + .insert(mid.to_string(), stream_name.clone()); + state + .session_stream_names + .entry(session.id.clone()) + .or_default() + .insert(stream_name.clone(), mid.to_string()); + } - if let Some(producer) = state - .streams - .get(stream_name) - .and_then(|stream| stream.producer.clone()) + if let Some(producer) = state + .streams + .get(stream_name) + .and_then(|stream| stream.producer.clone()) + { + drop(state_guard); + if let Err(err) = + session.connect_input_stream(&self.obj(), &producer, webrtc_pad, &codecs) { - drop(state); - if let Err(err) = - session.connect_input_stream(&self.obj(), &producer, webrtc_pad, &codecs) - { - gst::error!( - CAT, - imp = self, - "Failed to connect input stream {} for session {}: {}", - stream_name, - session_id, - err - ); - remove = true; - state = self.state.lock().unwrap(); - break; - } - state = self.state.lock().unwrap(); - } else { gst::error!( CAT, imp = self, - "No producer to connect session {} to", - session_id, + "Failed to connect input stream {} for session {}: {}", + stream_name, + session.id, + err ); remove = true; + state_guard = self.state.lock().unwrap(); + state = state_guard.deref_mut(); + break; + } + drop(session); + state_guard = self.state.lock().unwrap(); + state = state_guard.deref_mut(); + session = session_clone.lock().unwrap(); + } else { + gst::error!( + CAT, + imp = self, + "No producer to connect session {} to", + session.id, + ); + remove = true; + break; + } + } + + session.pipeline.debug_to_dot_file_with_ts( + gst::DebugGraphDetails::all(), + format!("webrtcsink-peer-{}-remote-description-set", session.id), + ); + + let this_weak = self.downgrade(); + let webrtcbin = session.webrtcbin.downgrade(); + let session_id_clone = session.id.clone(); + session.stats_collection_handle = Some(RUNTIME.spawn(async move { + let mut interval = tokio::time::interval(std::time::Duration::from_millis(100)); + + loop { + interval.tick().await; + if let (Some(webrtcbin), Some(this)) = (webrtcbin.upgrade(), this_weak.upgrade()) { + this.process_stats(webrtcbin, &session_id_clone); + } else { break; } } + })); - session.pipeline.debug_to_dot_file_with_ts( - gst::DebugGraphDetails::all(), - format!("webrtcsink-peer-{session_id}-remote-description-set",), - ); - - let this_weak = self.downgrade(); - let webrtcbin = session.webrtcbin.downgrade(); - let session_id_clone = session_id.clone(); - session.stats_collection_handle = Some(RUNTIME.spawn(async move { - let mut interval = tokio::time::interval(std::time::Duration::from_millis(100)); - - loop { - interval.tick().await; - if let (Some(webrtcbin), Some(this)) = - (webrtcbin.upgrade(), this_weak.upgrade()) - { - this.process_stats(webrtcbin, &session_id_clone); - } else { - break; - } - } - })); - - if remove { - let _ = state.sessions.remove(&session_id); - state.finalize_session(&self.obj(), &mut session); - drop(state); - let settings = self.settings.lock().unwrap(); - let signaller = settings.signaller.clone(); - drop(settings); - signaller.end_session(&session_id); - } else if let Some(session_wrapper) = state.sessions.get_mut(&session_id) { - session_wrapper.restore(&self.obj(), session); - } else { - gst::warning!(CAT, imp = self, "Session {session_id} was removed"); - } + if remove { + let _ = state.sessions.remove(&session.id); + state.finalize_session(&self.obj(), &mut session); + drop(state_guard); + let settings = self.settings.lock().unwrap(); + let signaller = settings.signaller.clone(); + drop(settings); + signaller.end_session(&session.id); } } @@ -3788,7 +3647,7 @@ impl BaseWebRTCSink { _sdp_mid: Option, candidate: &str, ) { - let mut state = self.state.lock().unwrap(); + let state = self.state.lock().unwrap(); let sdp_m_line_index = match sdp_m_line_index { Some(sdp_m_line_index) => sdp_m_line_index, @@ -3798,8 +3657,16 @@ impl BaseWebRTCSink { } }; - if let Some(session_wrapper) = state.sessions.get_mut(session_id) { - session_wrapper.add_ice_candidate(&self.obj(), session_id, sdp_m_line_index, candidate); + if let Some(session) = state.sessions.get(session_id) { + let session = session.0.lock().unwrap(); + gst::trace!( + CAT, + imp = self, + "adding ice candidate for session {session_id}" + ); + session + .webrtcbin + .emit_by_name::<()>("add-ice-candidate", &[&sdp_m_line_index, &candidate]); } else { gst::warning!(CAT, imp = self, "No consumer with ID {session_id}"); } @@ -3808,8 +3675,8 @@ impl BaseWebRTCSink { fn handle_sdp_answer(&self, session_id: &str, desc: &gst_webrtc::WebRTCSessionDescription) { let mut state = self.state.lock().unwrap(); - if let Some(session) = state.sessions.get_mut(session_id) { - let session = session.unwrap_mut(); + if let Some(session) = state.sessions.get(session_id).map(|s| s.0.clone()) { + let mut session = session.lock().unwrap(); let sdp = desc.sdp(); @@ -3886,13 +3753,13 @@ impl BaseWebRTCSink { session_id, move |reply| { gst::debug!(CAT, imp = this, "received reply {:?}", reply); - this.on_remote_description_set(session_id); + this.on_remote_description_set(&session_id); } )); - session - .webrtcbin - .emit_by_name::<()>("set-remote-description", &[desc, &promise]); + let webrtcbin = session.webrtcbin.clone(); + drop(session); + webrtcbin.emit_by_name::<()>("set-remote-description", &[desc, &promise]); } else { gst::warning!(CAT, imp = self, "No consumer with ID {session_id}"); } @@ -4179,7 +4046,7 @@ impl BaseWebRTCSink { .map(|(name, consumer)| { ( name.as_str(), - consumer.unwrap().gather_stats().to_send_value(), + consumer.0.lock().unwrap().gather_stats().to_send_value(), ) }), ) @@ -4264,16 +4131,15 @@ impl BaseWebRTCSink { // update video encoder info used when downscaling/downsampling the input let stream_name = pad.name().to_string(); - state - .sessions - .values_mut() - .flat_map(|session| session.unwrap_mut().encoders.iter_mut()) - .filter(|encoder| encoder.stream_name == stream_name) - .for_each(|encoder| { - encoder.halved_framerate = - video_info.fps().mul(gst::Fraction::new(1, 2)); - encoder.video_info = video_info.clone(); - }); + for session in state.sessions.values() { + for encoder in session.0.lock().unwrap().encoders.iter_mut() { + if encoder.stream_name == stream_name { + encoder.halved_framerate = + video_info.fps().mul(gst::Fraction::new(1, 2)); + encoder.video_info = video_info.clone(); + } + } + } } } }