aws: Fix race condition when unlocking

It would be possible that there is no cancellable yet when unlock() is
called, then a new future is executed and it wouldn't have any
information that it is not supposed to run at all.

To solve this remember if unlock() was called and reset this in
unlock_stop().

Also implement actual unlocking in s3hlssink.

Part-of: <https://gitlab.freedesktop.org/gstreamer/gst-plugins-rs/-/merge_requests/1602>
This commit is contained in:
Sebastian Dröge 2024-06-07 18:55:43 +03:00 committed by GStreamer Marge Bot
parent 00aaecad07
commit 51f6d3986f
5 changed files with 92 additions and 64 deletions

View file

@ -8,7 +8,6 @@
// //
// SPDX-License-Identifier: MPL-2.0 // SPDX-License-Identifier: MPL-2.0
use futures::future;
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use std::io::Write; use std::io::Write;
use std::str::FromStr; use std::str::FromStr;
@ -90,7 +89,7 @@ pub struct S3HlsSink {
settings: Mutex<Settings>, settings: Mutex<Settings>,
state: Mutex<State>, state: Mutex<State>,
hlssink: gst::Element, hlssink: gst::Element,
canceller: Mutex<Option<future::AbortHandle>>, canceller: Mutex<s3utils::Canceller>,
} }
static CAT: Lazy<gst::DebugCategory> = Lazy::new(|| { static CAT: Lazy<gst::DebugCategory> = Lazy::new(|| {
@ -459,7 +458,7 @@ impl ObjectSubclass for S3HlsSink {
settings: Mutex::new(Settings::default()), settings: Mutex::new(Settings::default()),
state: Mutex::new(State::Stopped), state: Mutex::new(State::Stopped),
hlssink, hlssink,
canceller: Mutex::new(None), canceller: Mutex::new(s3utils::Canceller::default()),
} }
} }
} }
@ -803,10 +802,19 @@ impl ElementImpl for S3HlsSink {
PAD_TEMPLATES.as_ref() PAD_TEMPLATES.as_ref()
} }
#[allow(clippy::single_match)]
fn change_state( fn change_state(
&self, &self,
transition: gst::StateChange, transition: gst::StateChange,
) -> Result<gst::StateChangeSuccess, gst::StateChangeError> { ) -> Result<gst::StateChangeSuccess, gst::StateChangeError> {
match transition {
gst::StateChange::PausedToReady => {
let mut canceller = self.canceller.lock().unwrap();
canceller.abort();
}
_ => (),
}
let ret = self.parent_change_state(transition)?; let ret = self.parent_change_state(transition)?;
/* /*
* The settings lock must not be taken before the parent state change. * The settings lock must not be taken before the parent state change.
@ -850,6 +858,11 @@ impl ElementImpl for S3HlsSink {
} }
} }
gst::StateChange::PausedToReady => {
let mut canceller = self.canceller.lock().unwrap();
*canceller = s3utils::Canceller::None;
}
gst::StateChange::ReadyToNull => { gst::StateChange::ReadyToNull => {
drop(settings); drop(settings);
/* /*

View file

@ -26,7 +26,6 @@ use aws_sdk_s3::{
Client, Client,
}; };
use futures::future;
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use std::collections::HashMap; use std::collections::HashMap;
use std::convert::From; use std::convert::From;
@ -186,8 +185,8 @@ pub struct S3Sink {
url: Mutex<Option<GstS3Url>>, url: Mutex<Option<GstS3Url>>,
settings: Mutex<Settings>, settings: Mutex<Settings>,
state: Mutex<State>, state: Mutex<State>,
canceller: Mutex<Option<future::AbortHandle>>, canceller: Mutex<s3utils::Canceller>,
abort_multipart_canceller: Mutex<Option<future::AbortHandle>>, abort_multipart_canceller: Mutex<s3utils::Canceller>,
} }
static CAT: Lazy<gst::DebugCategory> = Lazy::new(|| { static CAT: Lazy<gst::DebugCategory> = Lazy::new(|| {
@ -618,19 +617,6 @@ impl S3Sink {
Ok(()) Ok(())
} }
fn cancel(&self) {
let mut canceller = self.canceller.lock().unwrap();
let mut abort_canceller = self.abort_multipart_canceller.lock().unwrap();
if let Some(c) = abort_canceller.take() {
c.abort()
};
if let Some(c) = canceller.take() {
c.abort()
};
}
fn set_uri(self: &S3Sink, url_str: Option<&str>) -> Result<(), glib::Error> { fn set_uri(self: &S3Sink, url_str: Option<&str>) -> Result<(), glib::Error> {
let state = self.state.lock().unwrap(); let state = self.state.lock().unwrap();
@ -1103,8 +1089,18 @@ impl BaseSinkImpl for S3Sink {
} }
fn unlock(&self) -> Result<(), gst::ErrorMessage> { fn unlock(&self) -> Result<(), gst::ErrorMessage> {
self.cancel(); let mut canceller = self.canceller.lock().unwrap();
let mut abort_canceller = self.abort_multipart_canceller.lock().unwrap();
canceller.abort();
abort_canceller.abort();
Ok(())
}
fn unlock_stop(&self) -> Result<(), gst::ErrorMessage> {
let mut canceller = self.canceller.lock().unwrap();
let mut abort_canceller = self.abort_multipart_canceller.lock().unwrap();
*canceller = s3utils::Canceller::None;
*abort_canceller = s3utils::Canceller::None;
Ok(()) Ok(())
} }

View file

@ -21,7 +21,6 @@ use aws_sdk_s3::{
Client, Client,
}; };
use futures::future;
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use std::collections::HashMap; use std::collections::HashMap;
use std::convert::From; use std::convert::From;
@ -158,7 +157,7 @@ pub struct S3PutObjectSink {
url: Mutex<Option<GstS3Url>>, url: Mutex<Option<GstS3Url>>,
settings: Mutex<Settings>, settings: Mutex<Settings>,
state: Mutex<State>, state: Mutex<State>,
canceller: Mutex<Option<future::AbortHandle>>, canceller: Mutex<s3utils::Canceller>,
} }
static CAT: Lazy<gst::DebugCategory> = Lazy::new(|| { static CAT: Lazy<gst::DebugCategory> = Lazy::new(|| {
@ -328,14 +327,6 @@ impl S3PutObjectSink {
Ok(()) Ok(())
} }
fn cancel(&self) {
let mut canceller = self.canceller.lock().unwrap();
if let Some(c) = canceller.take() {
c.abort()
};
}
fn set_uri(self: &S3PutObjectSink, url_str: Option<&str>) -> Result<(), glib::Error> { fn set_uri(self: &S3PutObjectSink, url_str: Option<&str>) -> Result<(), glib::Error> {
let state = self.state.lock().unwrap(); let state = self.state.lock().unwrap();
@ -756,8 +747,14 @@ impl BaseSinkImpl for S3PutObjectSink {
} }
fn unlock(&self) -> Result<(), gst::ErrorMessage> { fn unlock(&self) -> Result<(), gst::ErrorMessage> {
self.cancel(); let mut canceller = self.canceller.lock().unwrap();
canceller.abort();
Ok(())
}
fn unlock_stop(&self) -> Result<(), gst::ErrorMessage> {
let mut canceller = self.canceller.lock().unwrap();
*canceller = s3utils::Canceller::None;
Ok(()) Ok(())
} }

View file

@ -7,7 +7,6 @@
// SPDX-License-Identifier: MPL-2.0 // SPDX-License-Identifier: MPL-2.0
use bytes::Bytes; use bytes::Bytes;
use futures::future;
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use std::sync::Mutex; use std::sync::Mutex;
use std::time::Duration; use std::time::Duration;
@ -77,7 +76,7 @@ impl Default for Settings {
pub struct S3Src { pub struct S3Src {
settings: Mutex<Settings>, settings: Mutex<Settings>,
state: Mutex<StreamingState>, state: Mutex<StreamingState>,
canceller: Mutex<Option<future::AbortHandle>>, canceller: Mutex<s3utils::Canceller>,
} }
static CAT: Lazy<gst::DebugCategory> = Lazy::new(|| { static CAT: Lazy<gst::DebugCategory> = Lazy::new(|| {
@ -89,14 +88,6 @@ static CAT: Lazy<gst::DebugCategory> = Lazy::new(|| {
}); });
impl S3Src { impl S3Src {
fn cancel(&self) {
let mut canceller = self.canceller.lock().unwrap();
if let Some(c) = canceller.take() {
c.abort()
};
}
fn connect(self: &S3Src, url: &GstS3Url) -> Result<Client, gst::ErrorMessage> { fn connect(self: &S3Src, url: &GstS3Url) -> Result<Client, gst::ErrorMessage> {
let settings = self.settings.lock().unwrap(); let settings = self.settings.lock().unwrap();
let timeout_config = s3utils::timeout_config(settings.request_timeout); let timeout_config = s3utils::timeout_config(settings.request_timeout);
@ -521,9 +512,6 @@ impl BaseSrcImpl for S3Src {
} }
fn stop(&self) -> Result<(), gst::ErrorMessage> { fn stop(&self) -> Result<(), gst::ErrorMessage> {
// First, stop any asynchronous tasks if we're running, as they will have the state lock
self.cancel();
let mut state = self.state.lock().unwrap(); let mut state = self.state.lock().unwrap();
if let StreamingState::Stopped = *state { if let StreamingState::Stopped = *state {
@ -587,7 +575,14 @@ impl BaseSrcImpl for S3Src {
} }
fn unlock(&self) -> Result<(), gst::ErrorMessage> { fn unlock(&self) -> Result<(), gst::ErrorMessage> {
self.cancel(); let mut canceller = self.canceller.lock().unwrap();
canceller.abort();
Ok(())
}
fn unlock_stop(&self) -> Result<(), gst::ErrorMessage> {
let mut canceller = self.canceller.lock().unwrap();
*canceller = s3utils::Canceller::None;
Ok(()) Ok(())
} }
} }

View file

@ -51,21 +51,38 @@ impl<E: ProvideErrorMetadata + std::error::Error> fmt::Display for WaitError<E>
} }
} }
pub fn wait<F, T, E>( #[derive(Default)]
canceller: &Mutex<Option<future::AbortHandle>>, pub enum Canceller {
future: F, #[default]
) -> Result<T, WaitError<E>> None,
Handle(future::AbortHandle),
Cancelled,
}
impl Canceller {
pub fn abort(&mut self) {
if let Canceller::Handle(ref canceller) = *self {
canceller.abort();
}
*self = Canceller::Cancelled;
}
}
pub fn wait<F, T, E>(canceller_mutex: &Mutex<Canceller>, future: F) -> Result<T, WaitError<E>>
where where
F: Send + Future<Output = Result<T, E>>, F: Send + Future<Output = Result<T, E>>,
F::Output: Send, F::Output: Send,
T: Send, T: Send,
E: Send, E: Send,
{ {
let mut canceller_guard = canceller.lock().unwrap(); let mut canceller = canceller_mutex.lock().unwrap();
if matches!(*canceller, Canceller::Cancelled) {
return Err(WaitError::Cancelled);
}
let (abort_handle, abort_registration) = future::AbortHandle::new_pair(); let (abort_handle, abort_registration) = future::AbortHandle::new_pair();
*canceller = Canceller::Handle(abort_handle);
canceller_guard.replace(abort_handle); drop(canceller);
drop(canceller_guard);
let abortable_future = future::Abortable::new(future, abort_registration); let abortable_future = future::Abortable::new(future, abort_registration);
@ -86,17 +103,21 @@ where
}; };
/* Clear out the canceller */ /* Clear out the canceller */
canceller_guard = canceller.lock().unwrap(); let mut canceller = canceller_mutex.lock().unwrap();
*canceller_guard = None; if matches!(*canceller, Canceller::Cancelled) {
return Err(WaitError::Cancelled);
}
*canceller = Canceller::None;
drop(canceller);
res res
} }
pub fn wait_stream( pub fn wait_stream(
canceller: &Mutex<Option<future::AbortHandle>>, canceller_mutex: &Mutex<Canceller>,
stream: &mut ByteStream, stream: &mut ByteStream,
) -> Result<Bytes, WaitError<ByteStreamError>> { ) -> Result<Bytes, WaitError<ByteStreamError>> {
wait(canceller, async move { wait(canceller_mutex, async move {
let mut collect = BytesMut::new(); let mut collect = BytesMut::new();
// Loop over the stream and collect till we're done // Loop over the stream and collect till we're done
@ -116,7 +137,7 @@ pub fn timeout_config(request_timeout: Duration) -> TimeoutConfig {
} }
pub fn wait_config( pub fn wait_config(
canceller: &Mutex<Option<future::AbortHandle>>, canceller_mutex: &Mutex<Canceller>,
region: Region, region: Region,
timeout_config: TimeoutConfig, timeout_config: TimeoutConfig,
credentials: Option<Credentials>, credentials: Option<Credentials>,
@ -136,11 +157,13 @@ pub fn wait_config(
.load(), .load(),
}; };
let mut canceller_guard = canceller.lock().unwrap(); let mut canceller = canceller_mutex.lock().unwrap();
if matches!(*canceller, Canceller::Cancelled) {
return Err(WaitError::Cancelled);
}
let (abort_handle, abort_registration) = future::AbortHandle::new_pair(); let (abort_handle, abort_registration) = future::AbortHandle::new_pair();
*canceller = Canceller::Handle(abort_handle);
canceller_guard.replace(abort_handle); drop(canceller);
drop(canceller_guard);
let abortable_future = future::Abortable::new(config_future, abort_registration); let abortable_future = future::Abortable::new(config_future, abort_registration);
@ -157,8 +180,12 @@ pub fn wait_config(
}; };
/* Clear out the canceller */ /* Clear out the canceller */
canceller_guard = canceller.lock().unwrap(); let mut canceller = canceller_mutex.lock().unwrap();
*canceller_guard = None; if matches!(*canceller, Canceller::Cancelled) {
return Err(WaitError::Cancelled);
}
*canceller = Canceller::None;
drop(canceller);
res res
} }