webrtcsink: fix session not in place errors

The InPlace/Taken logic was introduced to avoid using an extra lock
around the session, but it places expectations that are not always
obvious to meet around when a session is expected to be taken or not.

Any code that expects to have access to the sessions at all times thus
needs either extra logic in the session wrapper, or to maintain the
state of the session outside of the session (eg mids).

This commit removes the logic, and wraps sessions in Arc<Mutex>>.

Part-of: <https://gitlab.freedesktop.org/gstreamer/gst-plugins-rs/-/merge_requests/1852>
This commit is contained in:
Mathieu Duponchelle 2024-10-15 17:16:19 +02:00 committed by GStreamer Marge Bot
parent ef06421a25
commit 959463ff65

View file

@ -24,6 +24,7 @@ use itertools::Itertools;
use once_cell::sync::Lazy;
use std::collections::HashMap;
use std::ops::DerefMut;
use std::ops::Mul;
use std::sync::{mpsc, Arc, Condvar, Mutex};
@ -295,7 +296,7 @@ pub struct VideoEncoder {
stream_name: String,
}
struct Session {
struct SessionInner {
id: String,
pipeline: gst::Pipeline,
@ -325,6 +326,9 @@ struct Session {
stats_collection_handle: Option<tokio::task::JoinHandle<()>>,
}
#[derive(Clone)]
struct Session(Arc<Mutex<SessionInner>>);
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
enum SignallerState {
Started,
@ -343,156 +347,10 @@ struct SignallerSignals {
shutdown: glib::SignalHandlerId,
}
struct IceCandidate {
sdp_m_line_index: u32,
candidate: String,
}
/// Wrapper around `Session`.
///
/// This makes it possible for the `Session` to be taken out of the `State`,
/// without removing the entry in the `sessions` `HashMap`, thus allowing
/// the `State` lock to be released, e.g. before calling a `Signal`.
///
/// Taking the `Session`, replaces it with a placeholder which can enqueue
/// items (currently ICE candidates) received while the `Session` is taken.
/// In which case, the enqueued items will be processed when the `Session` is
/// restored.
enum SessionWrapper {
/// The `Session` is available in the `SessionWrapper`.
InPlace(Session),
/// The `Session` was taken out the `SessionWrapper`.
Taken(Vec<IceCandidate>),
}
impl SessionWrapper {
/// Unwraps a reference to the `Session` of this `SessionWrapper`.
///
/// # Panics
///
/// Panics is the `Session` was taken.
fn unwrap(&self) -> &Session {
match self {
SessionWrapper::InPlace(session) => session,
_ => panic!("Session is not In Place"),
}
}
/// Unwraps a mutable reference to the `Session` of this `SessionWrapper`.
///
/// # Panics
///
/// Panics is the `Session` was taken.
fn unwrap_mut(&mut self) -> &mut Session {
match self {
SessionWrapper::InPlace(session) => session,
_ => panic!("Session is not In Place"),
}
}
/// Consumes the `SessionWrapper`, returning the wrapped `Session`.
///
/// # Panics
///
/// Panics is the `Session` was taken.
fn into_inner(self) -> Session {
match self {
SessionWrapper::InPlace(session) => session,
_ => panic!("Session is not In Place"),
}
}
/// Takes the `Session` out of this `SessionWrapper`, leaving it in the `Taken` state.
///
/// # Panics
///
/// Panics is the `Session` was taken.
fn take(&mut self) -> Session {
use SessionWrapper::*;
match std::mem::replace(self, Taken(Vec::new())) {
InPlace(session) => session,
_ => panic!("Session is not In Place"),
}
}
/// Restores a `Session` to this `SessionWrapper`.
///
/// Processes any pending items enqueued while the `Session` was taken.
///
/// # Panics
///
/// Panics is the `Session` is already in place.
fn restore(&mut self, element: &super::BaseWebRTCSink, session: Session) {
let SessionWrapper::Taken(ref cands) = self else {
panic!("Session is already in place");
};
if !cands.is_empty() {
gst::trace!(
CAT,
obj = element,
"handling {} pending ice candidates for session {}",
cands.len(),
session.id,
);
for cand in cands {
session.webrtcbin.emit_by_name::<()>(
"add-ice-candidate",
&[&cand.sdp_m_line_index, &cand.candidate],
);
}
}
*self = SessionWrapper::InPlace(session);
}
/// Adds an ICE candidate to this `SessionWrapper`.
///
/// If the `Session` is in place, the ICE candidate is added immediately,
/// otherwise, it will be added when the `Session` is restored.
fn add_ice_candidate(
&mut self,
element: &super::BaseWebRTCSink,
session_id: &str,
sdp_m_line_index: u32,
candidate: &str,
) {
match self {
SessionWrapper::InPlace(session) => {
gst::trace!(
CAT,
obj = element,
"adding ice candidate for session {session_id}"
);
session
.webrtcbin
.emit_by_name::<()>("add-ice-candidate", &[&sdp_m_line_index, &candidate]);
}
SessionWrapper::Taken(cands) => {
gst::trace!(
CAT,
obj = element,
"queueing ice candidate for session {session_id}"
);
cands.push(IceCandidate {
sdp_m_line_index,
candidate: candidate.to_string(),
});
}
}
}
}
impl From<Session> for SessionWrapper {
fn from(session: Session) -> Self {
SessionWrapper::InPlace(session)
}
}
/* Our internal state */
struct State {
signaller_state: SignallerState,
sessions: HashMap<String, SessionWrapper>,
sessions: HashMap<String, Session>,
codecs: BTreeMap<i32, Codec>,
/// Used to abort codec discovery
codecs_abort_handles: Vec<futures::future::AbortHandle>,
@ -1315,7 +1173,7 @@ impl VideoEncoder {
}
impl State {
fn finalize_session(&mut self, element: &super::BaseWebRTCSink, session: &mut Session) {
fn finalize_session(&mut self, element: &super::BaseWebRTCSink, session: &mut SessionInner) {
gst::info!(CAT, "Ending session {}", session.id);
session.pipeline.debug_to_dot_file_with_ts(
gst::DebugGraphDetails::all(),
@ -1359,8 +1217,7 @@ impl State {
session_id: &str,
) -> Option<Session> {
if let Some(session) = self.sessions.remove(session_id) {
let mut session = session.into_inner();
self.finalize_session(element, &mut session);
self.finalize_session(element, &mut session.0.lock().unwrap());
Some(session)
} else {
None
@ -1395,7 +1252,7 @@ impl State {
}
}
impl Session {
impl SessionInner {
fn new(
id: String,
pipeline: gst::Pipeline,
@ -2366,7 +2223,7 @@ impl BaseWebRTCSink {
drop(state);
gst::debug!(CAT, imp = self, "Ending sessions");
for session in sessions {
signaller.end_session(&session.id);
signaller.end_session(&session.0.lock().unwrap().id);
}
gst::debug!(CAT, imp = self, "All sessions have started finalizing");
@ -2505,6 +2362,8 @@ impl BaseWebRTCSink {
if let Some(session) = state.sessions.get(session_id) {
session
.0
.lock()
.unwrap()
.webrtcbin
.emit_by_name::<()>("set-local-description", &[&offer, &None::<gst::Promise>]);
@ -2528,10 +2387,9 @@ impl BaseWebRTCSink {
let settings = self.settings.lock().unwrap();
let signaller = settings.signaller.clone();
drop(settings);
let mut state = self.state.lock().unwrap();
if let Some(session) = state.sessions.get_mut(session_id) {
let mut session = session.take();
if let Some(session) = self.state.lock().unwrap().sessions.get(session_id).cloned() {
let mut session = session.0.lock().unwrap();
let sdp = answer.sdp();
session.sdp = Some(sdp.to_owned());
@ -2543,23 +2401,10 @@ impl BaseWebRTCSink {
.and_then(|format| format.parse::<i32>().ok());
}
drop(state);
session
.webrtcbin
.emit_by_name::<()>("set-local-description", &[&answer, &None::<gst::Promise>]);
let mut state = self.state.lock().unwrap();
let session_id = session.id.clone();
if let Some(session_wrapper) = state.sessions.get_mut(&session_id) {
session_wrapper.restore(&self.obj(), session);
} else {
gst::warning!(CAT, imp = self, "Session {session_id} was removed");
}
drop(state);
let maybe_munged_answer = if signaller
.has_property("manual-sdp-munging", Some(bool::static_type()))
&& signaller.property("manual-sdp-munging")
@ -2568,12 +2413,15 @@ impl BaseWebRTCSink {
answer
} else {
// Use the default munging mechanism (signal registered by user)
signaller.munge_sdp(session_id.as_str(), &answer)
signaller.munge_sdp(&session.id, &answer)
};
signaller.send_sdp(&session_id, &maybe_munged_answer);
signaller.send_sdp(&session.id, &maybe_munged_answer);
self.on_remote_description_set(session_id)
let session_id = session.id.clone();
drop(session);
self.on_remote_description_set(&session_id)
}
}
@ -2639,10 +2487,9 @@ impl BaseWebRTCSink {
}
));
session
.unwrap()
.webrtcbin
.emit_by_name::<()>("create-answer", &[&None::<gst::Structure>, &promise]);
let webrtcbin = session.0.lock().unwrap().webrtcbin.clone();
webrtcbin.emit_by_name::<()>("create-answer", &[&None::<gst::Structure>, &promise]);
}
}
@ -2773,8 +2620,10 @@ impl BaseWebRTCSink {
gst::debug!(CAT, imp = self, "Negotiating for session {}", session_id);
if let Some(session) = state.sessions.get(session_id) {
let session = session.unwrap();
let session = session.0.lock().unwrap();
gst::trace!(CAT, imp = self, "WebRTC pads: {:?}", session.webrtc_pads);
let webrtcbin = session.webrtcbin.clone();
drop(session);
if let Some(offer) = offer {
let promise = gst::Promise::with_change_func(glib::clone!(
@ -2788,9 +2637,7 @@ impl BaseWebRTCSink {
}
));
session
.webrtcbin
.emit_by_name::<()>("set-remote-description", &[&offer, &promise]);
webrtcbin.emit_by_name::<()>("set-remote-description", &[&offer, &promise]);
} else {
gst::debug!(CAT, imp = self, "Creating offer for session {}", session_id);
let promise = gst::Promise::with_change_func(glib::clone!(
@ -2842,9 +2689,7 @@ impl BaseWebRTCSink {
}
));
session
.webrtcbin
.emit_by_name::<()>("create-offer", &[&None::<gst::Structure>, &promise]);
webrtcbin.emit_by_name::<()>("create-offer", &[&None::<gst::Structure>, &promise]);
}
} else {
gst::debug!(
@ -3196,8 +3041,16 @@ impl BaseWebRTCSink {
let state = this.state.lock().unwrap();
if let Some(session) = state.sessions.get(&session_id) {
for webrtc_pad in session.unwrap().webrtc_pads.values() {
if let Some(srcpad) = webrtc_pad.pad.peer() {
let pads: Vec<gst::Pad> = session
.0
.lock()
.unwrap()
.webrtc_pads
.values()
.map(|p| p.pad.clone())
.collect();
for pad in pads {
if let Some(srcpad) = pad.peer() {
srcpad.send_event(
gst_video::UpstreamForceKeyUnitEvent::builder()
.all_headers(true)
@ -3236,7 +3089,7 @@ impl BaseWebRTCSink {
),
);
let session = Session::new(
let session = SessionInner::new(
session_id.clone(),
pipeline.clone(),
webrtcbin.clone(),
@ -3273,7 +3126,7 @@ impl BaseWebRTCSink {
let mut state = element.imp().state.lock().unwrap();
if let Some(session) = state.sessions.get_mut(&session_id_str) {
let session = session.unwrap_mut();
let mut session = session.0.lock().unwrap();
if session.stats_sigid.is_none() {
let session_id_str = session_id_str.clone();
session.stats_sigid = Some(rtp_session.connect_notify(
@ -3383,9 +3236,10 @@ impl BaseWebRTCSink {
}
});
state
.sessions
.insert(session_id.to_string(), session.into());
state.sessions.insert(
session_id.to_string(),
Session(Arc::new(Mutex::new(session))),
);
let mut streams: Vec<InputStream> = state.streams.values().cloned().collect();
@ -3464,7 +3318,7 @@ impl BaseWebRTCSink {
{
let mut state = this.state.lock().unwrap();
if let Some(session) = state.sessions.get_mut(&session_id) {
let session = session.unwrap_mut();
let mut session = session.0.lock().unwrap();
session.webrtc_pads = webrtc_pads;
if offer.is_some() {
session.codecs = Some(codecs);
@ -3540,6 +3394,7 @@ impl BaseWebRTCSink {
if let Some(session) = state.end_session(&self.obj(), session_id) {
drop(state);
let session = session.0.lock().unwrap();
signaller
.emit_by_name::<()>("consumer-removed", &[&session.peer_id, &session.webrtcbin]);
if signal {
@ -3555,7 +3410,9 @@ impl BaseWebRTCSink {
fn process_loss_stats(&self, session_id: &str, stats: &gst::Structure) {
let mut state = self.state.lock().unwrap();
if let Some(session) = state.sessions.get_mut(session_id) {
let session = session.unwrap_mut();
/* We need this two-step approach for split-borrowing */
let mut session_guard = session.0.lock().unwrap();
let session = session_guard.deref_mut();
if let Some(congestion_controller) = session.congestion_controller.as_mut() {
congestion_controller.loss_control(&self.obj(), stats, &mut session.encoders);
}
@ -3574,7 +3431,9 @@ impl BaseWebRTCSink {
if let Ok(Some(stats)) = reply {
let mut state = this.state.lock().unwrap();
if let Some(session) = state.sessions.get_mut(&session_id) {
let session = session.unwrap_mut();
/* We need this two-step approach for split-borrowing */
let mut session_guard = session.0.lock().unwrap();
let session = session_guard.deref_mut();
if let Some(congestion_controller) = session.congestion_controller.as_mut()
{
congestion_controller.delay_control(
@ -3597,7 +3456,7 @@ impl BaseWebRTCSink {
let mut state = self.state.lock().unwrap();
if let Some(session) = state.sessions.get_mut(session_id) {
session.unwrap_mut().rtprtxsend = Some(rtprtxsend);
session.0.lock().unwrap().rtprtxsend = Some(rtprtxsend);
}
}
@ -3607,7 +3466,7 @@ impl BaseWebRTCSink {
let mut state = self.state.lock().unwrap();
if let Some(session) = state.sessions.get_mut(session_id) {
let session = session.unwrap_mut();
let mut session = session.0.lock().unwrap();
let n_encoders = session.encoders.len();
@ -3677,13 +3536,16 @@ impl BaseWebRTCSink {
}
}
fn on_remote_description_set(&self, session_id: String) {
let mut state = self.state.lock().unwrap();
fn on_remote_description_set(&self, session_id: &str) {
let mut state_guard = self.state.lock().unwrap();
let mut state = state_guard.deref_mut();
let Some(session_clone) = state.sessions.get(session_id).map(|s| s.0.clone()) else {
return;
};
let mut remove = false;
let codecs = state.codecs.clone();
if let Some(session) = state.sessions.get_mut(&session_id) {
let mut session = session.take();
let mut session = session_clone.lock().unwrap();
for webrtc_pad in session.webrtc_pads.clone().values() {
let transceiver = webrtc_pad
@ -3697,12 +3559,12 @@ impl BaseWebRTCSink {
if let Some(mid) = transceiver.mid() {
state
.session_mids
.entry(session_id.clone())
.entry(session.id.clone())
.or_default()
.insert(mid.to_string(), stream_name.clone());
state
.session_stream_names
.entry(session_id.clone())
.entry(session.id.clone())
.or_default()
.insert(stream_name.clone(), mid.to_string());
}
@ -3712,7 +3574,7 @@ impl BaseWebRTCSink {
.get(stream_name)
.and_then(|stream| stream.producer.clone())
{
drop(state);
drop(state_guard);
if let Err(err) =
session.connect_input_stream(&self.obj(), &producer, webrtc_pad, &codecs)
{
@ -3721,20 +3583,24 @@ impl BaseWebRTCSink {
imp = self,
"Failed to connect input stream {} for session {}: {}",
stream_name,
session_id,
session.id,
err
);
remove = true;
state = self.state.lock().unwrap();
state_guard = self.state.lock().unwrap();
state = state_guard.deref_mut();
break;
}
state = self.state.lock().unwrap();
drop(session);
state_guard = self.state.lock().unwrap();
state = state_guard.deref_mut();
session = session_clone.lock().unwrap();
} else {
gst::error!(
CAT,
imp = self,
"No producer to connect session {} to",
session_id,
session.id,
);
remove = true;
break;
@ -3743,20 +3609,18 @@ impl BaseWebRTCSink {
session.pipeline.debug_to_dot_file_with_ts(
gst::DebugGraphDetails::all(),
format!("webrtcsink-peer-{session_id}-remote-description-set",),
format!("webrtcsink-peer-{}-remote-description-set", session.id),
);
let this_weak = self.downgrade();
let webrtcbin = session.webrtcbin.downgrade();
let session_id_clone = session_id.clone();
let session_id_clone = session.id.clone();
session.stats_collection_handle = Some(RUNTIME.spawn(async move {
let mut interval = tokio::time::interval(std::time::Duration::from_millis(100));
loop {
interval.tick().await;
if let (Some(webrtcbin), Some(this)) =
(webrtcbin.upgrade(), this_weak.upgrade())
{
if let (Some(webrtcbin), Some(this)) = (webrtcbin.upgrade(), this_weak.upgrade()) {
this.process_stats(webrtcbin, &session_id_clone);
} else {
break;
@ -3765,18 +3629,13 @@ impl BaseWebRTCSink {
}));
if remove {
let _ = state.sessions.remove(&session_id);
let _ = state.sessions.remove(&session.id);
state.finalize_session(&self.obj(), &mut session);
drop(state);
drop(state_guard);
let settings = self.settings.lock().unwrap();
let signaller = settings.signaller.clone();
drop(settings);
signaller.end_session(&session_id);
} else if let Some(session_wrapper) = state.sessions.get_mut(&session_id) {
session_wrapper.restore(&self.obj(), session);
} else {
gst::warning!(CAT, imp = self, "Session {session_id} was removed");
}
signaller.end_session(&session.id);
}
}
@ -3788,7 +3647,7 @@ impl BaseWebRTCSink {
_sdp_mid: Option<String>,
candidate: &str,
) {
let mut state = self.state.lock().unwrap();
let state = self.state.lock().unwrap();
let sdp_m_line_index = match sdp_m_line_index {
Some(sdp_m_line_index) => sdp_m_line_index,
@ -3798,8 +3657,16 @@ impl BaseWebRTCSink {
}
};
if let Some(session_wrapper) = state.sessions.get_mut(session_id) {
session_wrapper.add_ice_candidate(&self.obj(), session_id, sdp_m_line_index, candidate);
if let Some(session) = state.sessions.get(session_id) {
let session = session.0.lock().unwrap();
gst::trace!(
CAT,
imp = self,
"adding ice candidate for session {session_id}"
);
session
.webrtcbin
.emit_by_name::<()>("add-ice-candidate", &[&sdp_m_line_index, &candidate]);
} else {
gst::warning!(CAT, imp = self, "No consumer with ID {session_id}");
}
@ -3808,8 +3675,8 @@ impl BaseWebRTCSink {
fn handle_sdp_answer(&self, session_id: &str, desc: &gst_webrtc::WebRTCSessionDescription) {
let mut state = self.state.lock().unwrap();
if let Some(session) = state.sessions.get_mut(session_id) {
let session = session.unwrap_mut();
if let Some(session) = state.sessions.get(session_id).map(|s| s.0.clone()) {
let mut session = session.lock().unwrap();
let sdp = desc.sdp();
@ -3886,13 +3753,13 @@ impl BaseWebRTCSink {
session_id,
move |reply| {
gst::debug!(CAT, imp = this, "received reply {:?}", reply);
this.on_remote_description_set(session_id);
this.on_remote_description_set(&session_id);
}
));
session
.webrtcbin
.emit_by_name::<()>("set-remote-description", &[desc, &promise]);
let webrtcbin = session.webrtcbin.clone();
drop(session);
webrtcbin.emit_by_name::<()>("set-remote-description", &[desc, &promise]);
} else {
gst::warning!(CAT, imp = self, "No consumer with ID {session_id}");
}
@ -4179,7 +4046,7 @@ impl BaseWebRTCSink {
.map(|(name, consumer)| {
(
name.as_str(),
consumer.unwrap().gather_stats().to_send_value(),
consumer.0.lock().unwrap().gather_stats().to_send_value(),
)
}),
)
@ -4264,16 +4131,15 @@ impl BaseWebRTCSink {
// update video encoder info used when downscaling/downsampling the input
let stream_name = pad.name().to_string();
state
.sessions
.values_mut()
.flat_map(|session| session.unwrap_mut().encoders.iter_mut())
.filter(|encoder| encoder.stream_name == stream_name)
.for_each(|encoder| {
for session in state.sessions.values() {
for encoder in session.0.lock().unwrap().encoders.iter_mut() {
if encoder.stream_name == stream_name {
encoder.halved_framerate =
video_info.fps().mul(gst::Fraction::new(1, 2));
encoder.video_info = video_info.clone();
});
}
}
}
}
}
}