diff --git a/generic/threadshare/Cargo.toml b/generic/threadshare/Cargo.toml index 49dfe61e..e9dafbce 100644 --- a/generic/threadshare/Cargo.toml +++ b/generic/threadshare/Cargo.toml @@ -14,8 +14,8 @@ gio = { git = "https://github.com/gtk-rs/gtk-rs-core" } gst = { package = "gstreamer", git = "https://gitlab.freedesktop.org/gstreamer/gstreamer-rs" } gst-net = { package = "gstreamer-net", git = "https://gitlab.freedesktop.org/gstreamer/gstreamer-rs" } gst-rtp = { package = "gstreamer-rtp", git = "https://gitlab.freedesktop.org/gstreamer/gstreamer-rs" } -pin-project = "1" once_cell = "1" +pin-project-lite = "0.2.7" tokio = { git = "https://github.com/fengalin/tokio", tag = "tokio-0.2.13-throttling.1", features = ["io-util", "macros", "rt-core", "sync", "stream", "time", "tcp", "udp", "rt-util"] } futures = { version = "0.3", features = ["thread-pool"] } rand = "0.8" diff --git a/generic/threadshare/src/runtime/executor/context.rs b/generic/threadshare/src/runtime/executor/context.rs index 67ae7bf6..33940582 100644 --- a/generic/threadshare/src/runtime/executor/context.rs +++ b/generic/threadshare/src/runtime/executor/context.rs @@ -16,26 +16,18 @@ // Free Software Foundation, Inc., 51 Franklin Street, Suite 500, // Boston, MA 02110-1335, USA. -use futures::channel::oneshot; -use futures::future::BoxFuture; use futures::prelude::*; -use gst::{gst_debug, gst_log, gst_trace, gst_warning}; +use gst::{gst_debug, gst_trace}; use once_cell::sync::Lazy; -use std::cell::RefCell; -use std::collections::{HashMap, VecDeque}; -use std::fmt; +use std::collections::HashMap; use std::io; -use std::mem; -use std::pin::Pin; -use std::sync::mpsc as sync_mpsc; -use std::sync::{Arc, Mutex, Weak}; -use std::task::Poll; -use std::thread; +use std::sync::{Arc, Mutex}; use std::time::Duration; +use super::{Handle, HandleWeak, JoinHandle, Scheduler, SubTaskOutput, TaskId}; use crate::runtime::RUNTIME_CAT; // We are bound to using `sync` for the `runtime` `Mutex`es. Attempts to use `async` `Mutex`es @@ -48,15 +40,9 @@ use crate::runtime::RUNTIME_CAT; // // Also, we want to be able to `acquire` a `Context` outside of an `async` context. // These `Mutex`es must be `lock`ed for a short period. -static CONTEXTS: Lazy>>> = +static CONTEXTS: Lazy, ContextWeak>>> = Lazy::new(|| Mutex::new(HashMap::new())); -thread_local!(static CURRENT_THREAD_CONTEXT: RefCell> = RefCell::new(None)); - -tokio::task_local! { - static CURRENT_TASK_ID: TaskId; -} - /// Blocks on `future` in one way or another if possible. /// /// IO & time related `Future`s must be handled within their own [`Context`]. @@ -98,43 +84,12 @@ pub fn block_on_or_add_sub_task(future: Fut) -> Op /// # Panics /// /// This function panics if called within a [`Context`] thread. -pub fn block_on(future: Fut) -> Fut::Output { +pub fn block_on(future: F) -> F::Output { assert!(!Context::is_context_thread()); // Not running in a Context thread so we can block gst_debug!(RUNTIME_CAT, "Blocking on new dummy context"); - - let context = Context(Arc::new(ContextInner { - real: None, - task_queues: Mutex::new((0, HashMap::new())), - })); - - CURRENT_THREAD_CONTEXT.with(move |cur_ctx| { - *cur_ctx.borrow_mut() = Some(context.downgrade()); - - let res = futures::executor::block_on(async move { - CURRENT_TASK_ID - .scope(TaskId(0), async move { - let task_id = CURRENT_TASK_ID.try_with(|task_id| *task_id).ok(); - assert_eq!(task_id, Some(TaskId(0))); - - let res = future.await; - - while Context::current_has_sub_tasks() { - if Context::drain_sub_tasks().await.is_err() { - break; - } - } - - res - }) - .await - }); - - *cur_ctx.borrow_mut() = None; - - res - }) + Scheduler::block_on(future) } /// Yields execution back to the runtime @@ -143,222 +98,8 @@ pub async fn yield_now() { tokio::task::yield_now().await; } -struct ContextThread { - name: String, -} - -impl ContextThread { - fn start(name: &str, wait: Duration) -> Context { - let context_thread = ContextThread { name: name.into() }; - let (context_sender, context_receiver) = sync_mpsc::channel(); - let join = thread::spawn(move || { - context_thread.spawn(wait, context_sender); - }); - - let context = context_receiver.recv().expect("Context thread init failed"); - *context - .0 - .real - .as_ref() - .unwrap() - .shutdown - .join - .lock() - .unwrap() = Some(join); - - context - } - - fn spawn(&self, wait: Duration, context_sender: sync_mpsc::Sender) { - gst_debug!(RUNTIME_CAT, "Started context thread '{}'", self.name); - - let mut runtime = tokio::runtime::Builder::new() - .basic_scheduler() - .thread_name(self.name.clone()) - .enable_all() - .max_throttling(wait) - .build() - .expect("Couldn't build the runtime"); - - let (shutdown_sender, shutdown_receiver) = oneshot::channel(); - - let shutdown = ContextShutdown { - name: self.name.clone(), - shutdown: Some(shutdown_sender), - join: Mutex::new(None), - }; - - let context = Context(Arc::new(ContextInner { - real: Some(ContextRealInner { - name: self.name.clone(), - wait_duration: wait, - handle: Mutex::new(runtime.handle().clone()), - shutdown, - }), - task_queues: Mutex::new((0, HashMap::new())), - })); - - CURRENT_THREAD_CONTEXT.with(|cur_ctx| { - *cur_ctx.borrow_mut() = Some(context.downgrade()); - }); - - context_sender.send(context).unwrap(); - - let _ = runtime.block_on(shutdown_receiver); - } -} - -impl Drop for ContextThread { - fn drop(&mut self) { - gst_debug!(RUNTIME_CAT, "Terminated: context thread '{}'", self.name); - } -} - -#[derive(Debug)] -struct ContextShutdown { - name: String, - shutdown: Option>, - join: Mutex>>, -} - -impl Drop for ContextShutdown { - fn drop(&mut self) { - gst_debug!( - RUNTIME_CAT, - "Shutting down context thread thread '{}'", - self.name - ); - self.shutdown.take().unwrap(); - - gst_trace!( - RUNTIME_CAT, - "Waiting for context thread '{}' to shutdown", - self.name - ); - let join_handle = self.join.lock().unwrap().take().unwrap(); - let _ = join_handle.join(); - } -} - -#[derive(Clone, Copy, Eq, PartialEq, Hash, Debug)] -pub struct TaskId(u64); - -pub type SubTaskOutput = Result<(), gst::FlowError>; -pub struct SubTaskQueue(VecDeque>); - -impl fmt::Debug for SubTaskQueue { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt.debug_tuple("SubTaskQueue").finish() - } -} - -pub struct JoinError(tokio::task::JoinError); - -impl fmt::Display for JoinError { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt::Display::fmt(&self.0, fmt) - } -} - -impl fmt::Debug for JoinError { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt::Debug::fmt(&self.0, fmt) - } -} - -impl std::error::Error for JoinError {} - -impl From for JoinError { - fn from(src: tokio::task::JoinError) -> Self { - JoinError(src) - } -} - -/// Wrapper for the underlying runtime JoinHandle implementation. -pub struct JoinHandle { - join_handle: tokio::task::JoinHandle, - context: ContextWeak, - task_id: TaskId, -} - -unsafe impl Send for JoinHandle {} -unsafe impl Sync for JoinHandle {} - -impl JoinHandle { - pub fn is_current(&self) -> bool { - if let Some((context, task_id)) = Context::current_task() { - let self_context = self.context.upgrade(); - self_context.map(|c| c == context).unwrap_or(false) && task_id == self.task_id - } else { - false - } - } - - pub fn context(&self) -> Option { - self.context.upgrade() - } - - pub fn task_id(&self) -> TaskId { - self.task_id - } -} - -impl Unpin for JoinHandle {} - -impl Future for JoinHandle { - type Output = Result; - - fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { - if self.as_ref().is_current() { - panic!("Trying to join task {:?} from itself", self.as_ref()); - } - - self.as_mut() - .join_handle - .poll_unpin(cx) - .map_err(JoinError::from) - } -} - -impl fmt::Debug for JoinHandle { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - let context_name = self.context.upgrade().map(|c| String::from(c.name())); - - fmt.debug_struct("JoinHandle") - .field("context", &context_name) - .field("task_id", &self.task_id) - .finish() - } -} - -#[derive(Debug)] -struct ContextRealInner { - name: String, - handle: Mutex, - wait_duration: Duration, - // Only used for dropping - shutdown: ContextShutdown, -} - -#[derive(Debug)] -struct ContextInner { - // Otherwise a dummy context - real: Option, - task_queues: Mutex<(u64, HashMap)>, -} - -impl Drop for ContextInner { - fn drop(&mut self) { - if let Some(ref real) = self.real { - let mut contexts = CONTEXTS.lock().unwrap(); - gst_debug!(RUNTIME_CAT, "Finalizing context '{}'", real.name); - contexts.remove(&real.name); - } - } -} - #[derive(Clone, Debug)] -pub struct ContextWeak(Weak); +pub struct ContextWeak(HandleWeak); impl ContextWeak { pub fn upgrade(&self) -> Option { @@ -374,16 +115,14 @@ impl ContextWeak { /// `Element` implementations should use [`PadSrc`] and [`PadSink`] which /// provide high-level features. /// -/// See the [module-level documentation](index.html) for more. -/// /// [`PadSrc`]: ../pad/struct.PadSrc.html /// [`PadSink`]: ../pad/struct.PadSink.html #[derive(Clone, Debug)] -pub struct Context(Arc); +pub struct Context(Handle); impl PartialEq for Context { fn eq(&self, other: &Self) -> bool { - Arc::ptr_eq(&self.0, &other.0) + self.0.eq(&other.0) } } @@ -391,90 +130,59 @@ impl Eq for Context {} impl Context { pub fn acquire(context_name: &str, wait: Duration) -> Result { - assert_ne!(context_name, "DUMMY"); + assert_ne!(context_name, Scheduler::DUMMY_NAME); let mut contexts = CONTEXTS.lock().unwrap(); - if let Some(inner_weak) = contexts.get(context_name) { - if let Some(inner_strong) = inner_weak.upgrade() { - gst_debug!( - RUNTIME_CAT, - "Joining Context '{}'", - inner_strong.real.as_ref().unwrap().name - ); - return Ok(Context(inner_strong)); + if let Some(context_weak) = contexts.get(context_name) { + if let Some(context) = context_weak.upgrade() { + gst_debug!(RUNTIME_CAT, "Joining Context '{}'", context.name()); + return Ok(context); } } - let context = ContextThread::start(context_name, wait); - contexts.insert(context_name.into(), Arc::downgrade(&context.0)); + let context = Context(Scheduler::start(context_name, wait)); + contexts.insert(context_name.into(), context.downgrade()); - gst_debug!( - RUNTIME_CAT, - "New Context '{}'", - context.0.real.as_ref().unwrap().name - ); + gst_debug!(RUNTIME_CAT, "New Context '{}'", context.name()); Ok(context) } pub fn downgrade(&self) -> ContextWeak { - ContextWeak(Arc::downgrade(&self.0)) + ContextWeak(self.0.downgrade()) } pub fn name(&self) -> &str { - match self.0.real { - Some(ref real) => real.name.as_str(), - None => "DUMMY", - } + self.0.context_name() } + // FIXME this could be renamed as max_throttling + // but then, all elements should also change their + // wait variables and properties to max_throttling. pub fn wait_duration(&self) -> Duration { - match self.0.real { - Some(ref real) => real.wait_duration, - None => Duration::ZERO, - } + self.0.max_throttling() } /// Returns `true` if a `Context` is running on current thread. pub fn is_context_thread() -> bool { - CURRENT_THREAD_CONTEXT.with(|cur_ctx| cur_ctx.borrow().is_some()) + Scheduler::is_scheduler_thread() } /// Returns the `Context` running on current thread, if any. pub fn current() -> Option { - CURRENT_THREAD_CONTEXT.with(|cur_ctx| { - cur_ctx - .borrow() - .as_ref() - .and_then(|ctx_weak| ctx_weak.upgrade()) - }) + Scheduler::current().map(Context) } /// Returns the `TaskId` running on current thread, if any. pub fn current_task() -> Option<(Context, TaskId)> { - CURRENT_THREAD_CONTEXT.with(|cur_ctx| { - cur_ctx - .borrow() - .as_ref() - .and_then(|ctx_weak| ctx_weak.upgrade()) - .and_then(|ctx| { - let task_id = CURRENT_TASK_ID.try_with(|task_id| *task_id).ok(); - - task_id.map(move |task_id| (ctx, task_id)) - }) - }) + Scheduler::current().map(Context).zip(TaskId::current()) } pub fn enter(&self, f: F) -> R where F: FnOnce() -> R, { - let real = match self.0.real { - Some(ref real) => real, - None => panic!("Can't enter on dummy context"), - }; - - real.handle.lock().unwrap().enter(f) + self.0.enter(f) } pub fn spawn(&self, future: Fut) -> JoinHandle @@ -482,7 +190,7 @@ impl Context { Fut: Future + Send + 'static, Fut::Output: Send + 'static, { - self.spawn_internal(future, false) + self.0.spawn(future, false) } pub fn awake_and_spawn(&self, future: Fut) -> JoinHandle @@ -490,80 +198,7 @@ impl Context { Fut: Future + Send + 'static, Fut::Output: Send + 'static, { - self.spawn_internal(future, true) - } - - #[inline] - fn spawn_internal(&self, future: Fut, must_awake: bool) -> JoinHandle - where - Fut: Future + Send + 'static, - Fut::Output: Send + 'static, - { - let real = match self.0.real { - Some(ref real) => real, - None => panic!("Can't spawn new tasks on dummy context"), - }; - - let mut task_queues = self.0.task_queues.lock().unwrap(); - let id = task_queues.0; - task_queues.0 += 1; - task_queues.1.insert(id, SubTaskQueue(VecDeque::new())); - - let id = TaskId(id); - gst_trace!( - RUNTIME_CAT, - "Spawning new task {:?} on context {}", - id, - real.name - ); - - let spawn_fut = async move { - let ctx = Context::current().unwrap(); - let real = ctx.0.real.as_ref().unwrap(); - - gst_trace!( - RUNTIME_CAT, - "Running task {:?} on context {}", - id, - real.name - ); - let res = CURRENT_TASK_ID.scope(id, future).await; - - // Remove task from the list - { - let mut task_queues = ctx.0.task_queues.lock().unwrap(); - if let Some(task_queue) = task_queues.1.remove(&id.0) { - let l = task_queue.0.len(); - if l > 0 { - gst_warning!( - RUNTIME_CAT, - "Task {:?} on context {} has {} pending sub tasks", - id, - real.name, - l - ); - } - } - } - - gst_trace!(RUNTIME_CAT, "Task {:?} on context {} done", id, real.name); - - res - }; - - let join_handle = { - if must_awake { - real.handle.lock().unwrap().awake_and_spawn(spawn_fut) - } else { - real.handle.lock().unwrap().spawn(spawn_fut) - } - }; - - JoinHandle { - join_handle, - context: self.downgrade(), - task_id: id, - } + self.0.spawn(future, true) } pub fn current_has_sub_tasks() -> bool { @@ -575,12 +210,7 @@ impl Context { } }; - let task_queues = ctx.0.task_queues.lock().unwrap(); - task_queues - .1 - .get(&task_id.0) - .map(|t| !t.0.is_empty()) - .unwrap_or(false) + ctx.0.has_sub_tasks(task_id) } pub fn add_sub_task(sub_task: T) -> Result<(), T> @@ -595,31 +225,7 @@ impl Context { } }; - let mut task_queues = ctx.0.task_queues.lock().unwrap(); - match task_queues.1.get_mut(&task_id.0) { - Some(task_queue) => { - if let Some(ref real) = ctx.0.real { - gst_trace!( - RUNTIME_CAT, - "Adding subtask to {:?} on context {}", - task_id, - real.name - ); - } else { - gst_trace!( - RUNTIME_CAT, - "Adding subtask to {:?} on dummy context", - task_id, - ); - } - task_queue.0.push_back(sub_task.boxed()); - Ok(()) - } - None => { - gst_trace!(RUNTIME_CAT, "Task was removed in the meantime"); - Err(sub_task) - } - } + ctx.0.add_sub_task(task_id, sub_task) } pub async fn drain_sub_tasks() -> SubTaskOutput { @@ -628,45 +234,13 @@ impl Context { None => return Ok(()), }; - ctx.drain_sub_tasks_internal(task_id).await + ctx.0.drain_sub_tasks(task_id).await } +} - fn drain_sub_tasks_internal( - &self, - id: TaskId, - ) -> impl Future + Send + 'static { - let mut task_queue = { - let mut task_queues = self.0.task_queues.lock().unwrap(); - if let Some(task_queue) = task_queues.1.get_mut(&id.0) { - mem::replace(task_queue, SubTaskQueue(VecDeque::new())) - } else { - SubTaskQueue(VecDeque::new()) - } - }; - - let name = self - .0 - .real - .as_ref() - .map(|r| r.name.clone()) - .unwrap_or_else(|| String::from("DUMMY")); - async move { - if !task_queue.0.is_empty() { - gst_log!( - RUNTIME_CAT, - "Scheduling draining {} sub tasks from {:?} on '{}'", - task_queue.0.len(), - id, - &name, - ); - - for task in task_queue.0.drain(..) { - task.await?; - } - } - - Ok(()) - } +impl From for Context { + fn from(handle: Handle) -> Self { + Context(handle) } } @@ -680,6 +254,7 @@ mod tests { use std::sync::Arc; use std::time::{Duration, Instant}; + use super::super::Scheduler; use super::Context; type Item = i32; @@ -692,19 +267,36 @@ mod tests { fn block_on_task_id() { gst::init().unwrap(); + assert!(!Context::is_context_thread()); + crate::runtime::executor::block_on(async { - let (_ctx, task_id) = Context::current_task().unwrap(); + let (ctx, task_id) = Context::current_task().unwrap(); + assert_eq!(ctx.name(), Scheduler::DUMMY_NAME); assert_eq!(task_id, super::TaskId(0)); - /* Adding the sub task fails let res = Context::add_sub_task(async move { let (_ctx, task_id) = Context::current_task().unwrap(); assert_eq!(task_id, super::TaskId(0)); Ok(()) }); assert!(res.is_ok()); - */ + assert!(Context::is_context_thread()); }); + + assert!(!Context::is_context_thread()); + } + + #[test] + fn block_on_timer() { + gst::init().unwrap(); + + let elapsed = crate::runtime::executor::block_on(async { + let now = Instant::now(); + crate::runtime::time::delay_for(DELAY).await; + now.elapsed() + }); + + assert!(elapsed >= DELAY); } #[test] @@ -713,7 +305,8 @@ mod tests { let context = Context::acquire("context_task_id", SLEEP_DURATION).unwrap(); let join_handle = context.spawn(async { - let (_ctx, task_id) = Context::current_task().unwrap(); + let (ctx, task_id) = Context::current_task().unwrap(); + assert_eq!(ctx.name(), "context_task_id"); assert_eq!(task_id, super::TaskId(0)); }); futures::executor::block_on(join_handle).unwrap(); @@ -810,7 +403,10 @@ mod tests { let mut receiver = futures::executor::block_on(join_handle).unwrap(); // The last sub task should be simply dropped at this point - assert_eq!(receiver.try_next().unwrap(), None); + match receiver.try_next() { + Ok(None) | Err(_) => (), + other => panic!("Unexpected {:?}", other), + } } #[test] @@ -853,30 +449,33 @@ mod tests { futures::executor::block_on(join_handle).unwrap_err(); } - #[tokio::test] - async fn enter_context_from_tokio() { + #[test] + fn enter_context_from_scheduler() { gst::init().unwrap(); - let context = Context::acquire("enter_context_from_tokio", SLEEP_DURATION).unwrap(); - let mut socket = context - .enter(|| { - let saddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 5002); - let socket = UdpSocket::bind(saddr).unwrap(); - tokio::net::UdpSocket::from_std(socket) - }) - .unwrap(); + let elapsed = crate::runtime::executor::block_on(async { + let context = Context::acquire("enter_context_from_tokio", SLEEP_DURATION).unwrap(); + let mut socket = context + .enter(|| { + let saddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 5002); + let socket = UdpSocket::bind(saddr).unwrap(); + tokio::net::UdpSocket::from_std(socket) + }) + .unwrap(); - let saddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 4000); - let bytes_sent = socket.send_to(&[0; 10], saddr).await.unwrap(); - assert_eq!(bytes_sent, 10); + let saddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 4000); + let bytes_sent = socket.send_to(&[0; 10], saddr).await.unwrap(); + assert_eq!(bytes_sent, 10); - let elapsed = context.enter(|| { - futures::executor::block_on(async { - let now = Instant::now(); - crate::runtime::time::delay_for(DELAY).await; - now.elapsed() + context.enter(|| { + futures::executor::block_on(async { + let now = Instant::now(); + crate::runtime::time::delay_for(DELAY).await; + now.elapsed() + }) }) }); + // Due to throttling, `Delay` may be fired earlier assert!(elapsed + SLEEP_DURATION / 2 >= DELAY); } diff --git a/generic/threadshare/src/runtime/executor/join.rs b/generic/threadshare/src/runtime/executor/join.rs new file mode 100644 index 00000000..e17229f3 --- /dev/null +++ b/generic/threadshare/src/runtime/executor/join.rs @@ -0,0 +1,106 @@ +// Copyright (C) 2018-2020 Sebastian Dröge +// Copyright (C) 2019-2021 François Laignel +// +// This library is free software; you can redistribute it and/or +// modify it under the terms of the GNU Library General Public +// License as published by the Free Software Foundation; either +// version 2 of the License, or (at your option) any later version. +// +// This library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// Library General Public License for more details. +// +// You should have received a copy of the GNU Library General Public +// License along with this library; if not, write to the +// Free Software Foundation, Inc., 51 Franklin Street, Suite 500, +// Boston, MA 02110-1335, USA. + +use futures::channel::oneshot; +use futures::prelude::*; + +use std::fmt; +use std::future::Future; +use std::pin::Pin; +use std::task::Poll; + +use super::context::Context; +use super::TaskId; +use super::{Handle, HandleWeak, Scheduler}; + +#[derive(Debug)] +pub struct JoinError(TaskId); + +impl fmt::Display for JoinError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "{:?} was Canceled", self.0) + } +} + +impl std::error::Error for JoinError {} + +pub struct JoinHandle { + receiver: oneshot::Receiver, + handle: HandleWeak, + task_id: TaskId, +} + +unsafe impl Send for JoinHandle {} +unsafe impl Sync for JoinHandle {} + +impl JoinHandle { + pub(super) fn new(receiver: oneshot::Receiver, handle: &Handle, task_id: TaskId) -> Self { + JoinHandle { + receiver, + handle: handle.downgrade(), + task_id, + } + } + + pub fn is_current(&self) -> bool { + if let Some((cur_scheduler, task_id)) = Scheduler::current().zip(TaskId::current()) { + self.handle.upgrade().map_or(false, |self_scheduler| { + self_scheduler == cur_scheduler && task_id == self.task_id + }) + } else { + false + } + } + + pub fn context(&self) -> Option { + self.handle.upgrade().map(Context::from) + } + + pub fn task_id(&self) -> TaskId { + self.task_id + } +} + +impl Future for JoinHandle { + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { + if self.as_ref().is_current() { + panic!("Trying to join task {:?} from itself", self.as_ref()); + } + + self.as_mut() + .receiver + .poll_unpin(cx) + .map_err(|_| JoinError(self.task_id)) + } +} + +impl fmt::Debug for JoinHandle { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + let context_name = self + .handle + .upgrade() + .map(|handle| handle.context_name().to_owned()); + + fmt.debug_struct("JoinHandle") + .field("context", &context_name) + .field("task_id", &self.task_id) + .finish() + } +} diff --git a/generic/threadshare/src/runtime/executor/mod.rs b/generic/threadshare/src/runtime/executor/mod.rs index 48395e5f..0cd3c6f3 100644 --- a/generic/threadshare/src/runtime/executor/mod.rs +++ b/generic/threadshare/src/runtime/executor/mod.rs @@ -34,6 +34,27 @@ //! [`PadSink`]: ../pad/struct.PadSink.html mod context; -pub use context::{ - block_on, block_on_or_add_sub_task, yield_now, Context, JoinHandle, SubTaskOutput, TaskId, -}; +pub use context::{block_on, block_on_or_add_sub_task, yield_now, Context}; + +mod scheduler; +use scheduler::{Handle, HandleWeak, Scheduler}; + +mod join; +pub use join::JoinHandle; + +mod task; +pub use task::{SubTaskOutput, TaskId}; + +struct CallOnDrop(Option); + +impl CallOnDrop { + fn new(f: F) -> Self { + CallOnDrop(Some(f)) + } +} + +impl Drop for CallOnDrop { + fn drop(&mut self) { + self.0.take().unwrap()() + } +} diff --git a/generic/threadshare/src/runtime/executor/scheduler.rs b/generic/threadshare/src/runtime/executor/scheduler.rs new file mode 100644 index 00000000..e39e636e --- /dev/null +++ b/generic/threadshare/src/runtime/executor/scheduler.rs @@ -0,0 +1,319 @@ +// Copyright (C) 2018-2020 Sebastian Dröge +// Copyright (C) 2019-2021 François Laignel +// +// This library is free software; you can redistribute it and/or +// modify it under the terms of the GNU Library General Public +// License as published by the Free Software Foundation; either +// version 2 of the License, or (at your option) any later version. +// +// This library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// Library General Public License for more details. +// +// You should have received a copy of the GNU Library General Public +// License along with this library; if not, write to the +// Free Software Foundation, Inc., 51 Franklin Street, Suite 500, +// Boston, MA 02110-1335, USA. + +use futures::channel::oneshot; + +use gio::glib::clone::Downgrade; +use gst::{gst_debug, gst_trace}; + +use std::cell::RefCell; +use std::future::Future; +use std::sync::mpsc as sync_mpsc; +use std::sync::{Arc, Mutex, Weak}; +use std::thread; +use std::time::Duration; + +use super::task::{SubTaskOutput, TaskFuture, TaskId, TaskQueue}; +use super::{CallOnDrop, JoinHandle}; +use crate::runtime::RUNTIME_CAT; + +thread_local! { + static CURRENT_SCHEDULER: RefCell>> = RefCell::new(None); +} + +#[derive(Debug)] +pub(super) struct Scheduler { + context_name: Arc, + max_throttling: Duration, + task_queue: Mutex, + rt_handle: Mutex, + shutdown: Mutex>, +} + +impl Scheduler { + pub const DUMMY_NAME: &'static str = "DUMMY"; + + pub fn start(context_name: &str, max_throttling: Duration) -> Handle { + let context_name = Arc::from(context_name); + + let (handle_sender, handle_receiver) = sync_mpsc::channel(); + let (shutdown_sender, shutdown_receiver) = oneshot::channel(); + let thread_ctx_name = Arc::clone(&context_name); + let join = thread::spawn(move || { + gst_debug!( + RUNTIME_CAT, + "Started Scheduler thread for Context '{}'", + thread_ctx_name + ); + + let (mut rt, handle) = Scheduler::init(thread_ctx_name, max_throttling); + handle_sender.send(handle.clone()).unwrap(); + + let _ = rt.block_on(shutdown_receiver); + }); + + let handle = handle_receiver.recv().expect("Context thread init failed"); + *handle.0.shutdown.lock().unwrap() = Some(SchedulerShutdown { + context_name, + sender: Some(shutdown_sender), + join: Some(join), + }); + + handle + } + + fn init(context_name: Arc, max_throttling: Duration) -> (tokio::runtime::Runtime, Handle) { + let runtime = tokio::runtime::Builder::new() + .basic_scheduler() + .enable_all() + .max_throttling(max_throttling) + .build() + .expect("Couldn't build the runtime"); + + let scheduler = Arc::new(Scheduler { + context_name: context_name.clone(), + max_throttling, + task_queue: Mutex::new(TaskQueue::new(context_name)), + rt_handle: Mutex::new(runtime.handle().clone()), + shutdown: Mutex::new(None), + }); + + CURRENT_SCHEDULER.with(|cur_scheduler| { + *cur_scheduler.borrow_mut() = Some(scheduler.downgrade()); + }); + + (runtime, scheduler.into()) + } + + pub fn block_on(future: F) -> ::Output { + assert!( + !Scheduler::is_scheduler_thread(), + "Attempt at blocking on from an existing Scheduler thread." + ); + let (mut rt, handle) = Scheduler::init(Scheduler::DUMMY_NAME.into(), Duration::ZERO); + + let handle_clone = handle.clone(); + let task = handle.0.task_queue.lock().unwrap().add(async move { + let res = future.await; + + let task_id = TaskId::current().unwrap(); + while handle_clone.has_sub_tasks(task_id) { + if handle_clone.drain_sub_tasks(task_id).await.is_err() { + break; + } + } + + res + }); + + let task_id = task.id(); + gst_trace!(RUNTIME_CAT, "Blocking on current thread with {:?}", task_id,); + + let _guard = CallOnDrop::new(|| { + gst_trace!( + RUNTIME_CAT, + "Blocking on current thread with {:?} done", + task_id, + ); + + handle.remove_task(task_id); + }); + + rt.block_on(task) + } + + pub(super) fn is_scheduler_thread() -> bool { + CURRENT_SCHEDULER.with(|cur_scheduler| cur_scheduler.borrow().is_some()) + } + + pub(super) fn current() -> Option { + CURRENT_SCHEDULER.with(|cur_scheduler| { + cur_scheduler + .borrow() + .as_ref() + .and_then(Weak::upgrade) + .map(Handle::from) + }) + } +} + +impl Drop for Scheduler { + fn drop(&mut self) { + // No more strong handlers point to this + // Scheduler, so remove its thread local key. + let _ = CURRENT_SCHEDULER.try_with(|cur_scheduler| { + *cur_scheduler.borrow_mut() = None; + }); + + gst_debug!( + RUNTIME_CAT, + "Terminated: Scheduler for Context '{}'", + self.context_name + ); + } +} + +#[derive(Debug)] +pub(super) struct SchedulerShutdown { + context_name: Arc, + sender: Option>, + join: Option>, +} + +impl Drop for SchedulerShutdown { + fn drop(&mut self) { + gst_debug!( + RUNTIME_CAT, + "Shutting down Scheduler thread for Context '{}'", + self.context_name + ); + self.sender.take().unwrap(); + + gst_trace!( + RUNTIME_CAT, + "Waiting for Scheduler to shutdown for Context '{}'", + self.context_name + ); + let _ = self.join.take().unwrap().join(); + } +} + +#[derive(Clone, Debug)] +pub(super) struct HandleWeak(Weak); + +impl HandleWeak { + pub(super) fn upgrade(&self) -> Option { + self.0.upgrade().map(Handle) + } +} + +#[derive(Clone, Debug)] +pub(super) struct Handle(Arc); + +impl Handle { + pub fn context_name(&self) -> &str { + &self.0.context_name + } + + pub fn max_throttling(&self) -> Duration { + self.0.max_throttling + } + + pub fn enter(&self, f: F) -> R + where + F: FnOnce() -> R, + { + self.0.rt_handle.lock().unwrap().enter(f) + } + + pub fn add_task(&self, future: F) -> TaskFuture { + let task = self.0.task_queue.lock().unwrap().add(future); + task + } + + pub fn remove_task(&self, task_id: TaskId) { + self.0.task_queue.lock().unwrap().remove(task_id); + } + + pub fn spawn(&self, future: F, must_awake: bool) -> JoinHandle + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + let task = self.add_task(future); + let task_id = task.id(); + let (sender, receiver) = oneshot::channel(); + + gst_trace!( + RUNTIME_CAT, + "Spawning new task_id {:?} on context {}", + task.id(), + self.0.context_name + ); + + let this = self.clone(); + let spawn_fut = async move { + gst_trace!( + RUNTIME_CAT, + "Running task_id {:?} on context {}", + task_id, + this.context_name() + ); + + let _guard = CallOnDrop::new(|| { + gst_trace!( + RUNTIME_CAT, + "Task {:?} on context {} done", + task_id, + this.context_name() + ); + + this.0.task_queue.lock().unwrap().remove(task_id); + }); + + let _ = sender.send(task.await); + }; + + if must_awake { + let _ = self.0.rt_handle.lock().unwrap().awake_and_spawn(spawn_fut); + } else { + let _ = self.0.rt_handle.lock().unwrap().spawn(spawn_fut); + } + + JoinHandle::new(receiver, self, task_id) + } + + pub fn has_sub_tasks(&self, task_id: TaskId) -> bool { + let ret = self.0.task_queue.lock().unwrap().has_sub_tasks(task_id); + ret + } + + pub fn add_sub_task(&self, task_id: TaskId, sub_task: T) -> Result<(), T> + where + T: Future + Send + 'static, + { + let res = self + .0 + .task_queue + .lock() + .unwrap() + .add_sub_task(task_id, sub_task); + res + } + + pub fn downgrade(&self) -> HandleWeak { + HandleWeak(self.0.downgrade()) + } + + pub async fn drain_sub_tasks(&self, task_id: TaskId) -> SubTaskOutput { + let sub_tasks_fut = self.0.task_queue.lock().unwrap().drain_sub_tasks(task_id); + sub_tasks_fut.await + } +} + +impl From> for Handle { + fn from(arc: Arc) -> Self { + Handle(arc) + } +} + +impl PartialEq for Handle { + fn eq(&self, other: &Self) -> bool { + Arc::ptr_eq(&self.0, &other.0) + } +} diff --git a/generic/threadshare/src/runtime/executor/task.rs b/generic/threadshare/src/runtime/executor/task.rs new file mode 100644 index 00000000..57c6b185 --- /dev/null +++ b/generic/threadshare/src/runtime/executor/task.rs @@ -0,0 +1,229 @@ +// Copyright (C) 2018-2020 Sebastian Dröge +// Copyright (C) 2019-2021 François Laignel +// +// This library is free software; you can redistribute it and/or +// modify it under the terms of the GNU Library General Public +// License as published by the Free Software Foundation; either +// version 2 of the License, or (at your option) any later version. +// +// This library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// Library General Public License for more details. +// +// You should have received a copy of the GNU Library General Public +// License along with this library; if not, write to the +// Free Software Foundation, Inc., 51 Franklin Street, Suite 500, +// Boston, MA 02110-1335, USA. + +use futures::future::BoxFuture; +use futures::prelude::*; + +use gst::{gst_log, gst_trace, gst_warning}; + +use pin_project_lite::pin_project; + +use std::cell::Cell; +use std::collections::{HashMap, VecDeque}; +use std::fmt; +use std::pin::Pin; +use std::sync::Arc; +use std::task::Poll; + +use crate::runtime::RUNTIME_CAT; + +thread_local! { + static CURRENT_TASK_ID: Cell> = Cell::new(None); +} + +#[derive(Clone, Copy, Eq, PartialEq, Hash, Debug)] +pub struct TaskId(pub(super) u64); + +impl TaskId { + const LAST: TaskId = TaskId(u64::MAX); + + fn next(task_id: Self) -> Self { + TaskId(task_id.0.wrapping_add(1)) + } + + pub(super) fn current() -> Option { + CURRENT_TASK_ID.try_with(Cell::get).ok().flatten() + } +} + +pub type SubTaskOutput = Result<(), gst::FlowError>; + +pin_project! { + pub(super) struct TaskFuture { + id: TaskId, + #[pin] + future: F, + } + +} + +impl TaskFuture { + pub fn id(&self) -> TaskId { + self.id + } +} + +impl Future for TaskFuture { + type Output = F::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { + struct TaskIdGuard { + prev_task_id: Option, + } + + impl Drop for TaskIdGuard { + fn drop(&mut self) { + let _ = CURRENT_TASK_ID.try_with(|cur| cur.replace(self.prev_task_id.take())); + } + } + + let task_id = self.id; + let project = self.project(); + + let _guard = TaskIdGuard { + prev_task_id: CURRENT_TASK_ID.with(|cur| cur.replace(Some(task_id))), + }; + + project.future.poll(cx) + } +} + +struct Task { + id: TaskId, + sub_tasks: VecDeque>, +} + +impl Task { + fn new(id: TaskId) -> Self { + Task { + id, + sub_tasks: VecDeque::new(), + } + } + + fn add_sub_task(&mut self, sub_task: T) + where + T: Future + Send + 'static, + { + self.sub_tasks.push_back(sub_task.boxed()); + } + + fn drain_sub_tasks(&mut self) -> VecDeque> { + std::mem::take(&mut self.sub_tasks) + } +} + +impl fmt::Debug for Task { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Task") + .field("id", &self.id) + .field("sub_tasks len", &self.sub_tasks.len()) + .finish() + } +} + +#[derive(Debug)] +pub(super) struct TaskQueue { + last_task_id: TaskId, + tasks: HashMap, + context_name: Arc, +} + +impl TaskQueue { + pub fn new(context_name: Arc) -> Self { + TaskQueue { + last_task_id: TaskId::LAST, + tasks: HashMap::default(), + context_name, + } + } + + pub fn add(&mut self, future: F) -> TaskFuture { + self.last_task_id = TaskId::next(self.last_task_id); + self.tasks + .insert(self.last_task_id, Task::new(self.last_task_id)); + + TaskFuture { + id: self.last_task_id, + future, + } + } + + pub fn remove(&mut self, task_id: TaskId) { + if let Some(task) = self.tasks.remove(&task_id) { + if !task.sub_tasks.is_empty() { + gst_warning!( + RUNTIME_CAT, + "Task {:?} on context {} has {} pending sub tasks", + task_id, + self.context_name, + task.sub_tasks.len(), + ); + } + } + } + + pub fn has_sub_tasks(&self, task_id: TaskId) -> bool { + self.tasks + .get(&task_id) + .map(|t| !t.sub_tasks.is_empty()) + .unwrap_or(false) + } + + pub fn add_sub_task(&mut self, task_id: TaskId, sub_task: T) -> Result<(), T> + where + T: Future + Send + 'static, + { + match self.tasks.get_mut(&task_id) { + Some(task) => { + gst_trace!( + RUNTIME_CAT, + "Adding subtask to {:?} on context {}", + task_id, + self.context_name + ); + task.add_sub_task(sub_task); + Ok(()) + } + None => { + gst_trace!(RUNTIME_CAT, "Task was removed in the meantime"); + Err(sub_task) + } + } + } + + pub fn drain_sub_tasks( + &mut self, + task_id: TaskId, + ) -> impl Future + Send + 'static { + let sub_tasks = self + .tasks + .get_mut(&task_id) + .map(|task| (task.drain_sub_tasks(), Arc::clone(&self.context_name))); + + async move { + if let Some((mut sub_tasks, context_name)) = sub_tasks { + if !sub_tasks.is_empty() { + gst_log!( + RUNTIME_CAT, + "Scheduling draining {} sub tasks from {:?} on '{}'", + sub_tasks.len(), + task_id, + &context_name, + ); + + for sub_task in sub_tasks.drain(..) { + sub_task.await?; + } + } + } + + Ok(()) + } + } +} diff --git a/generic/threadshare/src/runtime/mod.rs b/generic/threadshare/src/runtime/mod.rs index d22fe802..16430c17 100644 --- a/generic/threadshare/src/runtime/mod.rs +++ b/generic/threadshare/src/runtime/mod.rs @@ -31,6 +31,7 @@ //! See this [talk] ([slides]) for a presentation of the motivations and principles, //! and this [blog post]. //! +//! FIXME change this. //! Current implementation uses the crate [`tokio`]. //! //! Most `Element`s implementations should use the high-level features provided by [`PadSrc`] &