gst-plugins-rs/generic/threadshare/src/runtime/executor/task.rs
Commit 5e4fc8b138 (François Laignel, 2021-12-25): ts/executor: relax the static bound on enter
The function `enter` is executed in a blocking way from the caller's
point of view. This means that we can guarantee that the provided
function and its output will outlive the underlying Scheduler Task
execution. This requires an unsafe call to
`async_task::spawn_unchecked`. See:

https://docs.rs/async-task/latest/async_task/fn.spawn_unchecked.html
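
As a rough illustration of that argument, a blocking wrapper around the `add_sync` method below could look like the following. This is a hypothetical, simplified sketch (the helper name `run_sync` is made up, and the real Scheduler drives the Runnable from its own loop rather than inline); it assumes it lives inside this module and that gst logging is initialized:

    // Hypothetical helper, for illustration only.
    fn run_sync<F, O>(queue: &TaskQueue, f: F) -> O
    where
        F: FnOnce() -> O + Send,
        O: Send,
    {
        // Safety: we block right here until `f` has run and its output has
        // been retrieved, so `f` and its output outlive the spawned task
        // even though no `'static` bound is required.
        let handle = unsafe { queue.add_sync(f) };
        // Drive the Runnable that `add_sync` scheduled.
        while let Ok(runnable) = queue.pop_runnable() {
            runnable.run();
        }
        futures::executor::block_on(handle)
    }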
// Copyright (C) 2018-2020 Sebastian Dröge <sebastian@centricular.com>
// Copyright (C) 2019-2021 François Laignel <fengalin@free.fr>
//
// Take a look at the license at the top of the repository in the LICENSE file.
use async_task::Runnable;
use concurrent_queue::ConcurrentQueue;
use futures::future::BoxFuture;
use futures::prelude::*;
use gst::{gst_log, gst_trace, gst_warning};
use pin_project_lite::pin_project;
use slab::Slab;
use std::cell::Cell;
use std::collections::VecDeque;
use std::fmt;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::Poll;
use super::CallOnDrop;
use crate::runtime::RUNTIME_CAT;
thread_local! {
static CURRENT_TASK_ID: Cell<Option<TaskId>> = Cell::new(None);
}
#[derive(Clone, Copy, Eq, PartialEq, Hash, Debug)]
pub struct TaskId(pub(super) usize);
impl TaskId {
pub(super) fn current() -> Option<TaskId> {
CURRENT_TASK_ID.try_with(Cell::get).ok().flatten()
}
}
pub type SubTaskOutput = Result<(), gst::FlowError>;
pin_project! {
pub(super) struct TaskFuture<F: Future> {
id: TaskId,
#[pin]
future: F,
}
}
impl<F: Future> Future for TaskFuture<F> {
type Output = F::Output;
fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
struct TaskIdGuard {
prev_task_id: Option<TaskId>,
}
impl Drop for TaskIdGuard {
fn drop(&mut self) {
let _ = CURRENT_TASK_ID.try_with(|cur| cur.replace(self.prev_task_id.take()));
}
}
let task_id = self.id;
let project = self.project();
let _guard = TaskIdGuard {
prev_task_id: CURRENT_TASK_ID.with(|cur| cur.replace(Some(task_id))),
};
project.future.poll(cx)
}
}
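// Note (added for illustration, not in the original file): the guard above
// scopes the thread local CURRENT_TASK_ID to a single call to `poll`: while a
// `TaskFuture` is being polled on this thread, `TaskId::current()` returns its
// id, and the previous value (e.g. from a nested poll) is restored on drop.
// A minimal sketch, assuming `task_queue` is a `TaskQueue` driven on this
// thread (names are made up):
//
//     assert_eq!(TaskId::current(), None); // no TaskFuture is being polled here
//     let (expected_id, handle) = task_queue.add(async {
//         // Inside the task, the id assigned by `add` is observable.
//         TaskId::current().expect("polled via a TaskFuture")
//     });
//     // Once the queue has been driven to completion:
//     //     assert_eq!(futures::executor::block_on(handle), expected_id);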
struct Task {
id: TaskId,
sub_tasks: VecDeque<BoxFuture<'static, SubTaskOutput>>,
}
impl Task {
fn new(id: TaskId) -> Self {
Task {
id,
sub_tasks: VecDeque::new(),
}
}
fn add_sub_task<T>(&mut self, sub_task: T)
where
T: Future<Output = SubTaskOutput> + Send + 'static,
{
self.sub_tasks.push_back(sub_task.boxed());
}
fn drain_sub_tasks(&mut self) -> VecDeque<BoxFuture<'static, SubTaskOutput>> {
std::mem::take(&mut self.sub_tasks)
}
}
impl fmt::Debug for Task {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt.debug_struct("Task")
.field("id", &self.id)
.field("sub_tasks len", &self.sub_tasks.len())
.finish()
}
}
#[derive(Debug)]
pub(super) struct TaskQueue {
runnables: Arc<ConcurrentQueue<Runnable>>,
// FIXME: the good point about using a slab is that it's probably faster than a HashMap.
// However, since we reuse vacant entries, the same TaskId can be handed out again,
// which can harm debugging. If this is not acceptable, I'll switch back to using
// a HashMap.
tasks: Arc<Mutex<Slab<Task>>>,
context_name: Arc<str>,
}
impl TaskQueue {
pub fn new(context_name: Arc<str>) -> Self {
TaskQueue {
runnables: Arc::new(ConcurrentQueue::unbounded()),
tasks: Arc::new(Mutex::new(Slab::new())),
context_name,
}
}
pub fn add<F>(&self, future: F) -> (TaskId, async_task::Task<<F as Future>::Output>)
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
let tasks_clone = Arc::clone(&self.tasks);
let mut tasks = self.tasks.lock().unwrap();
let task_id = TaskId(tasks.vacant_entry().key());
let context_name = Arc::clone(&self.context_name);
let task_fut = async move {
gst_trace!(
RUNTIME_CAT,
"Running {:?} on context {}",
task_id,
context_name
);
let _guard = CallOnDrop::new(move || {
if let Some(task) = tasks_clone.lock().unwrap().try_remove(task_id.0) {
if !task.sub_tasks.is_empty() {
gst_warning!(
RUNTIME_CAT,
"Task {:?} on context {} has {} pending sub tasks",
task_id,
context_name,
task.sub_tasks.len(),
);
}
}
gst_trace!(
RUNTIME_CAT,
"Done {:?} on context {}",
task_id,
context_name
);
});
TaskFuture {
id: task_id,
future,
}
.await
};
let runnables = Arc::clone(&self.runnables);
let (runnable, task) = async_task::spawn(task_fut, move |runnable| {
runnables.push(runnable).unwrap();
});
tasks.insert(Task::new(task_id));
drop(tasks);
runnable.schedule();
(task_id, task)
}
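// Illustrative sketch (not part of the original file) of how `add` and
// `pop_runnable` fit together. In the real code the Scheduler owns this loop;
// the names below are made up and gst logging is assumed to be initialized:
//
//     let queue = TaskQueue::new(Arc::from("demo-context"));
//     let (_task_id, handle) = queue.add(async { 21 * 2 });
//     // `add` already scheduled the Runnable: the executor thread pops it
//     // and runs it, which polls the TaskFuture to completion.
//     while let Ok(runnable) = queue.pop_runnable() {
//         runnable.run();
//     }
//     assert_eq!(futures::executor::block_on(handle), 42);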
/// Adds a task to be blocked on immediately.
///
/// # Safety
///
/// The function and its output must outlive the execution
/// of the resulting task and the retrieval of the result.
pub unsafe fn add_sync<F, O>(&self, f: F) -> async_task::Task<O>
where
F: FnOnce() -> O + Send,
O: Send,
{
let tasks_clone = Arc::clone(&self.tasks);
let mut tasks = self.tasks.lock().unwrap();
let task_id = TaskId(tasks.vacant_entry().key());
let context_name = Arc::clone(&self.context_name);
let task_fut = async move {
gst_trace!(
RUNTIME_CAT,
"Executing sync function on context {} as {:?}",
context_name,
task_id,
);
let _guard = CallOnDrop::new(move || {
let _ = tasks_clone.lock().unwrap().try_remove(task_id.0);
gst_trace!(
RUNTIME_CAT,
"Done executing sync function on context {} as {:?}",
context_name,
task_id,
);
});
f()
};
let runnables = Arc::clone(&self.runnables);
// This is the unsafe call for which the lifetimes must hold
// until the Future is Ready and its Output has been retrieved.
let (runnable, task) = async_task::spawn_unchecked(task_fut, move |runnable| {
runnables.push(runnable).unwrap();
});
tasks.insert(Task::new(task_id));
drop(tasks);
runnable.schedule();
task
}
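// Note (added for illustration, not in the original file): per the `Safety`
// section above and the `spawn_unchecked` documentation linked in the commit
// message, a caller that does not block on the returned `Task` (for instance
// one that detaches it while `f` borrows a local) would break the contract:
// the queued Runnable could then outlive the borrow. The blocking `enter`
// path described in the commit message upholds the contract by construction.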
pub fn pop_runnable(&self) -> Result<Runnable, concurrent_queue::PopError> {
self.runnables.pop()
}
pub fn has_sub_tasks(&self, task_id: TaskId) -> bool {
self.tasks
.lock()
.unwrap()
.get(task_id.0)
.map(|t| !t.sub_tasks.is_empty())
.unwrap_or(false)
}
pub fn add_sub_task<T>(&self, task_id: TaskId, sub_task: T) -> Result<(), T>
where
T: Future<Output = SubTaskOutput> + Send + 'static,
{
match self.tasks.lock().unwrap().get_mut(task_id.0) {
Some(task) => {
gst_trace!(
RUNTIME_CAT,
"Adding subtask to {:?} on context {}",
task_id,
self.context_name
);
task.add_sub_task(sub_task);
Ok(())
}
None => {
gst_trace!(RUNTIME_CAT, "Task was removed in the meantime");
Err(sub_task)
}
}
}
pub fn drain_sub_tasks(
&self,
task_id: TaskId,
) -> impl Future<Output = SubTaskOutput> + Send + 'static {
let sub_tasks = self
.tasks
.lock()
.unwrap()
.get_mut(task_id.0)
.map(|task| (task.drain_sub_tasks(), Arc::clone(&self.context_name)));
async move {
if let Some((mut sub_tasks, context_name)) = sub_tasks {
if !sub_tasks.is_empty() {
gst_log!(
RUNTIME_CAT,
"Scheduling draining {} sub tasks from {:?} on '{}'",
sub_tasks.len(),
task_id,
&context_name,
);
for sub_task in sub_tasks.drain(..) {
sub_task.await?;
}
}
}
Ok(())
}
}
}
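// Illustrative sketch (not part of the original file) of the sub task flow:
// code running inside a task can queue extra futures against its own TaskId
// and drain them before completing. Names are made up; the queue is assumed
// to be wrapped in an Arc and driven as in the sketch after `add` above:
//
//     let queue = Arc::new(TaskQueue::new(Arc::from("demo-context")));
//     let queue_in_task = Arc::clone(&queue);
//     let (_id, handle) = queue.add(async move {
//         let id = TaskId::current().unwrap();
//         // Queue a sub task from within the task itself...
//         let _ = queue_in_task.add_sub_task(id, async { Ok::<(), gst::FlowError>(()) });
//         // ... then await all pending sub tasks before completing.
//         queue_in_task.drain_sub_tasks(id).await
//     });
//     // After driving the queue, blocking on `handle` yields Ok(()).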