webrtc: Move from async-std to tokio

Part-of: <https://gitlab.freedesktop.org/gstreamer/gst-plugins-rs/-/merge_requests/1019>
This commit is contained in:
Zhao, Gang 2022-12-12 21:36:11 +08:00 committed by GStreamer Marge Bot
parent 2bc29c1fd3
commit 1ffeb4d44d
8 changed files with 132 additions and 95 deletions

View file

@ -20,9 +20,10 @@ once_cell = "1.0"
anyhow = "1"
thiserror = "1"
futures = "0.3"
async-std = { version = "1", features = ["unstable"] }
async-native-tls = { version = "0.4.0" }
async-tungstenite = { version = "0.19", features = ["async-std-runtime", "async-native-tls"] }
tokio = { version = "1", features = ["fs", "macros", "rt-multi-thread", "time"] }
tokio-native-tls = "0.3.0"
tokio-stream = "0.1.11"
async-tungstenite = { version = "0.19", features = ["tokio-runtime", "tokio-native-tls"] }
serde = "1"
serde_json = "1"
fastrand = "1.0"

View file

@ -3,8 +3,9 @@ use std::sync::{Arc, Mutex};
use anyhow::Error;
use async_std::net::{TcpListener, TcpStream};
use async_std::task;
use tokio::net::{TcpListener, TcpStream};
use tokio::task;
use tokio::time;
use async_tungstenite::tungstenite::Message as WsMessage;
use clap::Parser;
use futures::channel::mpsc;
@ -150,9 +151,10 @@ async fn run(args: Args) -> Result<(), Error> {
let ws_clone = ws.downgrade();
let state_clone = state.clone();
task::spawn(async move {
let mut interval = async_std::stream::interval(std::time::Duration::from_millis(100));
let mut interval = time::interval(std::time::Duration::from_millis(100));
while interval.next().await.is_some() {
loop {
interval.tick().await;
if let Some(ws) = ws_clone.upgrade() {
let stats = ws.property::<gst::Structure>("stats");
let stats = serialize_value(&stats.to_value()).unwrap();
@ -193,7 +195,7 @@ async fn accept_connection(state: Arc<Mutex<State>>, stream: TcpStream) {
.expect("connected streams should have a peer address");
info!("Peer address: {}", addr);
let mut ws_stream = async_tungstenite::accept_async(stream)
let mut ws_stream = async_tungstenite::tokio::accept_async(stream)
.await
.expect("Error during the websocket handshake occurred");
@ -218,10 +220,11 @@ async fn accept_connection(state: Arc<Mutex<State>>, stream: TcpStream) {
});
}
fn main() -> Result<(), Error> {
#[tokio::main]
async fn main() -> Result<(), Error> {
gst::init()?;
let args = Args::parse();
task::block_on(run(args))
run(args).await
}

View file

@ -9,10 +9,11 @@ repository = "https://gitlab.freedesktop.org/gstreamer/gst-plugins-rs"
rust-version = "1.63"
[dependencies]
once_cell = "1.0"
anyhow = "1"
async-std = { version = "1", features = ["unstable", "attributes"] }
async-native-tls = "0.4"
async-tungstenite = { version = "0.19", features = ["async-std-runtime", "async-native-tls"] }
tokio = { version = "1", features = ["fs", "io-util", "macros", "rt-multi-thread", "time"] }
tokio-native-tls = "0.3.0"
async-tungstenite = { version = "0.19", features = ["tokio-runtime", "tokio-native-tls"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
clap = { version = "4", features = ["derive"] }

View file

@ -1,15 +1,16 @@
// SPDX-License-Identifier: MPL-2.0
use async_std::task;
use tokio::io::AsyncReadExt;
use tokio::task;
use clap::Parser;
use gst_plugin_webrtc_signalling::handlers::Handler;
use gst_plugin_webrtc_signalling::server::Server;
use tracing_subscriber::prelude::*;
use anyhow::Error;
use async_native_tls::TlsAcceptor;
use async_std::fs::File as AsyncFile;
use async_std::net::TcpListener;
use tokio_native_tls::native_tls::TlsAcceptor;
use tokio::fs;
use tokio::net::TcpListener;
use tracing::{info, warn};
#[derive(Parser, Debug)]
@ -49,55 +50,57 @@ fn initialize_logging(envvar_name: &str) -> Result<(), Error> {
Ok(())
}
fn main() -> Result<(), Error> {
#[tokio::main]
async fn main() -> Result<(), Error> {
let args = Args::parse();
let server = Server::spawn(|stream| Handler::new(stream));
initialize_logging("WEBRTCSINK_SIGNALLING_SERVER_LOG")?;
task::block_on(async move {
let addr = format!("{}:{}", args.host, args.port);
let addr = format!("{}:{}", args.host, args.port);
// Create the event loop and TCP listener we'll accept connections on.
let listener = TcpListener::bind(&addr).await?;
// Create the event loop and TCP listener we'll accept connections on.
let listener = TcpListener::bind(&addr).await?;
let acceptor = match args.cert {
Some(cert) => {
let key = AsyncFile::open(cert).await?;
Some(TlsAcceptor::new(key, args.cert_password.as_deref().unwrap_or("")).await?)
let acceptor = match args.cert {
Some(cert) => {
let mut file = fs::File::open(cert).await?;
let mut identity = vec![];
file.read_to_end(&mut identity).await?;
let identity = tokio_native_tls::native_tls::Identity::from_pkcs12(&identity, args.cert_password.as_deref().unwrap_or("")).unwrap();
Some(tokio_native_tls::TlsAcceptor::from(TlsAcceptor::new(identity).unwrap()))
}
None => None,
};
info!("Listening on: {}", addr);
while let Ok((stream, _)) = listener.accept().await {
let mut server_clone = server.clone();
let address = match stream.peer_addr() {
Ok(address) => address,
Err(err) => {
warn!("Connected peer with no address: {}", err);
continue;
}
None => None,
};
info!("Listening on: {}", addr);
info!("Accepting connection from {}", address);
while let Ok((stream, _)) = listener.accept().await {
let mut server_clone = server.clone();
let address = match stream.peer_addr() {
Ok(address) => address,
if let Some(ref acceptor) = acceptor {
let stream = match acceptor.accept(stream).await {
Ok(stream) => stream,
Err(err) => {
warn!("Connected peer with no address: {}", err);
warn!("Failed to accept TLS connection from {}: {}", address, err);
continue;
}
};
info!("Accepting connection from {}", address);
if let Some(ref acceptor) = acceptor {
let stream = match acceptor.accept(stream).await {
Ok(stream) => stream,
Err(err) => {
warn!("Failed to accept TLS connection from {}: {}", address, err);
continue;
}
};
task::spawn(async move { server_clone.accept_async(stream).await });
} else {
task::spawn(async move { server_clone.accept_async(stream).await });
}
task::spawn(async move { server_clone.accept_async(stream).await });
} else {
task::spawn(async move { server_clone.accept_async(stream).await });
}
}
Ok(())
})
Ok(())
}

View file

@ -381,7 +381,7 @@ mod tests {
);
}
#[async_std::test]
#[tokio::test]
async fn test_register_producer() {
let (mut tx, rx) = mpsc::unbounded();
let mut handler = Handler::new(Box::pin(rx));
@ -400,7 +400,7 @@ mod tests {
.unwrap();
}
#[async_std::test]
#[tokio::test]
async fn test_list_producers() {
let (mut tx, rx) = mpsc::unbounded();
let mut handler = Handler::new(Box::pin(rx));
@ -438,7 +438,7 @@ mod tests {
);
}
#[async_std::test]
#[tokio::test]
async fn test_welcome() {
let (mut tx, rx) = mpsc::unbounded();
let mut handler = Handler::new(Box::pin(rx));
@ -446,7 +446,7 @@ mod tests {
new_peer(&mut tx, &mut handler, "consumer").await;
}
#[async_std::test]
#[tokio::test]
async fn test_listener() {
let (mut tx, rx) = mpsc::unbounded();
let mut handler = Handler::new(Box::pin(rx));
@ -491,7 +491,7 @@ mod tests {
);
}
#[async_std::test]
#[tokio::test]
async fn test_start_session() {
let (mut tx, rx) = mpsc::unbounded();
let mut handler = Handler::new(Box::pin(rx));
@ -540,7 +540,7 @@ mod tests {
);
}
#[async_std::test]
#[tokio::test]
async fn test_remove_peer() {
let (mut tx, rx) = mpsc::unbounded();
let mut handler = Handler::new(Box::pin(rx));
@ -622,7 +622,7 @@ mod tests {
);
}
#[async_std::test]
#[tokio::test]
async fn test_end_session_consumer() {
let (mut tx, rx) = mpsc::unbounded();
let mut handler = Handler::new(Box::pin(rx));
@ -677,7 +677,7 @@ mod tests {
);
}
#[async_std::test]
#[tokio::test]
async fn test_disconnect_consumer() {
let (mut tx, rx) = mpsc::unbounded();
let mut handler = Handler::new(Box::pin(rx));
@ -726,7 +726,7 @@ mod tests {
);
}
#[async_std::test]
#[tokio::test]
async fn test_end_session_producer() {
let (mut tx, rx) = mpsc::unbounded();
let mut handler = Handler::new(Box::pin(rx));
@ -780,7 +780,7 @@ mod tests {
);
}
#[async_std::test]
#[tokio::test]
async fn test_end_session_twice() {
let (mut tx, rx) = mpsc::unbounded();
let mut handler = Handler::new(Box::pin(rx));
@ -854,7 +854,7 @@ mod tests {
);
}
#[async_std::test]
#[tokio::test]
async fn test_sdp_exchange() {
let (mut tx, rx) = mpsc::unbounded();
let mut handler = Handler::new(Box::pin(rx));
@ -916,7 +916,7 @@ mod tests {
);
}
#[async_std::test]
#[tokio::test]
async fn test_ice_exchange() {
let (mut tx, rx) = mpsc::unbounded();
let mut handler = Handler::new(Box::pin(rx));
@ -1004,7 +1004,7 @@ mod tests {
);
}
#[async_std::test]
#[tokio::test]
async fn test_sdp_exchange_wrong_direction_offer() {
let (mut tx, rx) = mpsc::unbounded();
let mut handler = Handler::new(Box::pin(rx));
@ -1064,7 +1064,7 @@ mod tests {
);
}
#[async_std::test]
#[tokio::test]
async fn test_start_session_no_producer() {
let (mut tx, rx) = mpsc::unbounded();
let mut handler = Handler::new(Box::pin(rx));
@ -1088,7 +1088,7 @@ mod tests {
);
}
#[async_std::test]
#[tokio::test]
async fn test_stop_producing() {
let (mut tx, rx) = mpsc::unbounded();
let mut handler = Handler::new(Box::pin(rx));
@ -1156,7 +1156,7 @@ mod tests {
);
}
#[async_std::test]
#[tokio::test]
async fn test_unregistering_with_listeners() {
let (mut tx, rx) = mpsc::unbounded();
let mut handler = Handler::new(Box::pin(rx));
@ -1258,7 +1258,7 @@ mod tests {
);
}
#[async_std::test]
#[tokio::test]
async fn test_start_session_no_consumer() {
let (mut tx, rx) = mpsc::unbounded();
let mut handler = Handler::new(Box::pin(rx));
@ -1290,7 +1290,7 @@ mod tests {
);
}
#[async_std::test]
#[tokio::test]
async fn test_start_session_twice() {
let (mut tx, rx) = mpsc::unbounded();
let mut handler = Handler::new(Box::pin(rx));
@ -1351,7 +1351,7 @@ mod tests {
assert_ne!(session0_id, session1_id);
}
#[async_std::test]
#[tokio::test]
async fn test_start_session_stop_producing() {
let (mut tx, rx) = mpsc::unbounded();
let mut handler = Handler::new(Box::pin(rx));

View file

@ -1,11 +1,11 @@
// SPDX-License-Identifier: MPL-2.0
use anyhow::Error;
use async_std::task;
use tokio::task;
use async_tungstenite::tungstenite::Message as WsMessage;
use futures::channel::mpsc;
use futures::prelude::*;
use futures::{AsyncRead, AsyncWrite};
use tokio::io::{AsyncRead, AsyncWrite};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::pin::Pin;
@ -99,7 +99,10 @@ impl Server {
if let Err(err) = peer.send_task_handle.await {
trace!(peer_id = %peer_id, "Error while joining send task: {}", err);
}
peer.receive_task_handle.await;
if let Err(err) = peer.receive_task_handle.await {
trace!(peer_id = %peer_id, "Error while joining receive task: {}", err);
}
});
}
}
@ -109,7 +112,7 @@ impl Server {
where
S: AsyncRead + AsyncWrite + Unpin + Send,
{
let ws = match async_tungstenite::accept_async(stream).await {
let ws = match async_tungstenite::tokio::accept_async(stream).await {
Ok(ws) => ws,
Err(err) => {
warn!("Error during the websocket handshake: {}", err);
@ -128,7 +131,7 @@ impl Server {
let (mut ws_sink, mut ws_stream) = ws.split();
let send_task_handle = task::spawn(async move {
loop {
match async_std::future::timeout(
match tokio::time::timeout(
std::time::Duration::from_secs(30),
websocket_receiver.next(),
)

View file

@ -2,7 +2,8 @@
use crate::webrtcsink::WebRTCSink;
use anyhow::{anyhow, Error};
use async_std::task;
use tokio::runtime;
use tokio::task;
use async_tungstenite::tungstenite::Message as WsMessage;
use futures::channel::mpsc;
use futures::prelude::*;
@ -24,6 +25,14 @@ static CAT: Lazy<gst::DebugCategory> = Lazy::new(|| {
)
});
static RUNTIME: Lazy<runtime::Runtime> = Lazy::new(|| {
runtime::Builder::new_multi_thread()
.enable_all()
.worker_threads(1)
.build()
.unwrap()
});
#[derive(Default)]
struct State {
/// Sender for the websocket messages
@ -58,15 +67,16 @@ impl Signaller {
let settings = self.settings.lock().unwrap().clone();
let connector = if let Some(path) = settings.cafile {
let cert = async_std::fs::read_to_string(&path).await?;
let cert = async_native_tls::Certificate::from_pem(cert.as_bytes())?;
let connector = async_native_tls::TlsConnector::new();
Some(connector.add_root_certificate(cert))
let cert = tokio::fs::read_to_string(&path).await?;
let cert = tokio_native_tls::native_tls::Certificate::from_pem(cert.as_bytes())?;
let mut connector_builder = tokio_native_tls::native_tls::TlsConnector::builder();
let connector = connector_builder.add_root_certificate(cert).build()?;
Some(tokio_native_tls::TlsConnector::from(connector))
} else {
None
};
let (ws, _) = async_tungstenite::async_std::connect_async_with_tls_connector(
let (ws, _) = async_tungstenite::tokio::connect_async_with_tls_connector(
settings.address.unwrap(),
connector,
)
@ -117,7 +127,7 @@ impl Signaller {
let element_clone = element.downgrade();
let receive_task_handle = task::spawn(async move {
while let Some(msg) = async_std::stream::StreamExt::next(&mut ws_stream).await {
while let Some(msg) = tokio_stream::StreamExt::next(&mut ws_stream).await {
if let Some(element) = element_clone.upgrade() {
match msg {
Ok(WsMessage::Text(msg)) => {
@ -275,7 +285,7 @@ impl Signaller {
if let Some(mut sender) = state.websocket_sender.clone() {
let element = element.downgrade();
task::spawn(async move {
RUNTIME.spawn(async move {
if let Err(err) = sender.send(msg).await {
if let Some(element) = element.upgrade() {
element.handle_signalling_error(anyhow!("Error: {}", err).into());
@ -305,7 +315,7 @@ impl Signaller {
if let Some(mut sender) = state.websocket_sender.clone() {
let element = element.downgrade();
task::spawn(async move {
RUNTIME.spawn(async move {
if let Err(err) = sender.send(msg).await {
if let Some(element) = element.upgrade() {
element.handle_signalling_error(anyhow!("Error: {}", err).into());
@ -322,17 +332,24 @@ impl Signaller {
let send_task_handle = state.send_task_handle.take();
let receive_task_handle = state.receive_task_handle.take();
if let Some(mut sender) = state.websocket_sender.take() {
task::block_on(async move {
let element = element.downgrade();
RUNTIME.block_on(async move {
sender.close_channel();
if let Some(handle) = send_task_handle {
if let Err(err) = handle.await {
gst::warning!(CAT, obj: element, "Error while joining send task: {}", err);
if let Some(element) = element.upgrade() {
gst::warning!(CAT, obj: element, "Error while joining send task: {}", err);
}
}
}
if let Some(handle) = receive_task_handle {
handle.await;
if let Err(err) = handle.await {
if let Some(element) = element.upgrade() {
gst::warning!(CAT, obj: element, "Error while joining receive task: {}", err);
}
}
}
});
}
@ -345,7 +362,7 @@ impl Signaller {
let session_id = session_id.to_string();
let element = element.downgrade();
if let Some(mut sender) = state.websocket_sender.clone() {
task::spawn(async move {
RUNTIME.spawn(async move {
if let Err(err) = sender
.send(p::IncomingMessage::EndSession(p::EndSessionMessage {
session_id: session_id.to_string(),

View file

@ -9,7 +9,7 @@ use gst_utils::StreamProducer;
use gst_video::subclass::prelude::*;
use gst_webrtc::WebRTCDataChannel;
use async_std::task;
use tokio::runtime;
use futures::prelude::*;
use anyhow::{anyhow, Error};
@ -31,6 +31,14 @@ static CAT: Lazy<gst::DebugCategory> = Lazy::new(|| {
)
});
static RUNTIME: Lazy<runtime::Runtime> = Lazy::new(|| {
runtime::Builder::new_multi_thread()
.enable_all()
.worker_threads(1)
.build()
.unwrap()
});
const CUDA_MEMORY_FEATURE: &str = "memory:CUDAMemory";
const GL_MEMORY_FEATURE: &str = "memory:GLMemory";
const NVMM_MEMORY_FEATURE: &str = "memory:NVMM";
@ -1226,7 +1234,7 @@ impl WebRTCSink {
}
if let Some(receiver) = state.codecs_done_receiver.take() {
task::block_on(async {
RUNTIME.spawn(async {
let _ = receiver.await;
});
}
@ -1678,7 +1686,7 @@ impl WebRTCSink {
let pipeline_clone = pipeline.downgrade();
let session_id_clone = session_id.to_owned();
task::spawn(async move {
RUNTIME.spawn(async move {
while let Some(msg) = bus_stream.next().await {
if let Some(element) = element_clone.upgrade() {
let this = element.imp();
@ -1931,11 +1939,12 @@ impl WebRTCSink {
let element_clone = element.downgrade();
let webrtcbin = session.webrtcbin.downgrade();
task::spawn(async move {
RUNTIME.spawn(async move {
let mut interval =
async_std::stream::interval(std::time::Duration::from_millis(100));
tokio::time::interval(std::time::Duration::from_millis(100));
while interval.next().await.is_some() {
loop {
interval.tick().await;
let element_clone = element_clone.clone();
if let (Some(webrtcbin), Some(element)) =
(webrtcbin.upgrade(), element_clone.upgrade())
@ -2294,7 +2303,7 @@ impl WebRTCSink {
if all_pads_have_caps {
let element_clone = element.downgrade();
task::spawn(async move {
RUNTIME.spawn(async move {
if let Some(element) = element_clone.upgrade() {
let this = element.imp();
let (fut, handle) =