// Copyright (C) 2018-2020 Sebastian Dröge // Copyright (C) 2019-2022 François Laignel // // Take a look at the license at the top of the repository in the LICENSE file. use async_task::Runnable; use concurrent_queue::ConcurrentQueue; use futures::future::BoxFuture; use futures::prelude::*; use pin_project_lite::pin_project; use slab::Slab; use std::cell::Cell; use std::collections::VecDeque; use std::fmt; use std::pin::Pin; use std::sync::{Arc, Mutex}; use std::task::Poll; use super::CallOnDrop; use crate::runtime::RUNTIME_CAT; thread_local! { static CURRENT_TASK_ID: Cell> = const { Cell::new(None) }; } #[derive(Clone, Copy, Eq, PartialEq, Hash, Debug)] pub struct TaskId(pub(super) usize); impl TaskId { pub(super) fn current() -> Option { CURRENT_TASK_ID.try_with(Cell::get).ok().flatten() } } pub type SubTaskOutput = Result<(), gst::FlowError>; pin_project! { pub(super) struct TaskFuture { id: TaskId, #[pin] future: F, } } impl Future for TaskFuture { type Output = F::Output; fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { struct TaskIdGuard { prev_task_id: Option, } impl Drop for TaskIdGuard { fn drop(&mut self) { let _ = CURRENT_TASK_ID.try_with(|cur| cur.replace(self.prev_task_id.take())); } } let task_id = self.id; let project = self.project(); let _guard = TaskIdGuard { prev_task_id: CURRENT_TASK_ID.with(|cur| cur.replace(Some(task_id))), }; project.future.poll(cx) } } struct Task { id: TaskId, sub_tasks: VecDeque>, } impl Task { fn new(id: TaskId) -> Self { Task { id, sub_tasks: VecDeque::new(), } } fn add_sub_task(&mut self, sub_task: T) where T: Future + Send + 'static, { self.sub_tasks.push_back(sub_task.boxed()); } } impl fmt::Debug for Task { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { fmt.debug_struct("Task") .field("id", &self.id) .field("sub_tasks len", &self.sub_tasks.len()) .finish() } } #[derive(Debug)] pub(super) struct TaskQueue { runnables: Arc>, // FIXME good point about using a slab is that it's probably faster than a HashMap // However since we reuse the vacant entries, we get the same TaskId // which can harm debugging. If this is not acceptable, I'll switch back to using // a HashMap. tasks: Arc>>, context_name: Arc, } impl TaskQueue { pub fn new(context_name: Arc) -> Self { TaskQueue { runnables: Arc::new(ConcurrentQueue::unbounded()), tasks: Arc::new(Mutex::new(Slab::new())), context_name, } } pub fn add(&self, future: F) -> (TaskId, async_task::Task<::Output>) where F: Future + Send + 'static, F::Output: Send + 'static, { let tasks_clone = Arc::clone(&self.tasks); let mut tasks = self.tasks.lock().unwrap(); let task_id = TaskId(tasks.vacant_entry().key()); let context_name = Arc::clone(&self.context_name); let task_fut = async move { gst::trace!( RUNTIME_CAT, "Running {:?} on context {}", task_id, context_name ); let _guard = CallOnDrop::new(move || { if let Some(task) = tasks_clone.lock().unwrap().try_remove(task_id.0) { if !task.sub_tasks.is_empty() { gst::warning!( RUNTIME_CAT, "Task {:?} on context {} has {} pending sub tasks", task_id, context_name, task.sub_tasks.len(), ); } } gst::trace!( RUNTIME_CAT, "Done {:?} on context {}", task_id, context_name ); }); TaskFuture { id: task_id, future, } .await }; let runnables = Arc::clone(&self.runnables); let (runnable, task) = async_task::spawn(task_fut, move |runnable| { runnables.push(runnable).unwrap(); }); tasks.insert(Task::new(task_id)); drop(tasks); runnable.schedule(); (task_id, task) } /// Adds a task to be blocked on immediately. /// /// # Safety /// /// The function and its output must outlive the execution /// of the resulting task and the retrieval of the result. pub unsafe fn add_sync(&self, f: F) -> async_task::Task where F: FnOnce() -> O + Send, O: Send, { let tasks_clone = Arc::clone(&self.tasks); let mut tasks = self.tasks.lock().unwrap(); let task_id = TaskId(tasks.vacant_entry().key()); let context_name = Arc::clone(&self.context_name); let task_fut = async move { gst::trace!( RUNTIME_CAT, "Executing sync function on context {} as {:?}", context_name, task_id, ); let _guard = CallOnDrop::new(move || { let _ = tasks_clone.lock().unwrap().try_remove(task_id.0); gst::trace!( RUNTIME_CAT, "Done executing sync function on context {} as {:?}", context_name, task_id, ); }); f() }; let runnables = Arc::clone(&self.runnables); // This is the unsafe call for which the lifetime must hold // until the the Future is Ready and its Output retrieved. let (runnable, task) = async_task::spawn_unchecked(task_fut, move |runnable| { runnables.push(runnable).unwrap(); }); tasks.insert(Task::new(task_id)); drop(tasks); runnable.schedule(); task } pub fn pop_runnable(&self) -> Result { self.runnables.pop() } pub fn add_sub_task(&self, task_id: TaskId, sub_task: T) -> Result<(), T> where T: Future + Send + 'static, { let mut state = self.tasks.lock().unwrap(); match state.get_mut(task_id.0) { Some(task) => { gst::trace!( RUNTIME_CAT, "Adding subtask to {:?} on context {}", task_id, self.context_name ); task.add_sub_task(sub_task); Ok(()) } None => { gst::trace!(RUNTIME_CAT, "Task was removed in the meantime"); Err(sub_task) } } } pub async fn drain_sub_tasks(&self, task_id: TaskId) -> SubTaskOutput { loop { let mut sub_tasks = match self.tasks.lock().unwrap().get_mut(task_id.0) { Some(task) if !task.sub_tasks.is_empty() => std::mem::take(&mut task.sub_tasks), _ => return Ok(()), }; gst::trace!( RUNTIME_CAT, "Scheduling draining {} sub tasks from {:?} on '{}'", sub_tasks.len(), task_id, self.context_name, ); for sub_task in sub_tasks.drain(..) { sub_task.await?; } } } }