ts: sync runtime with latest async-io

Part-of: <https://gitlab.freedesktop.org/gstreamer/gst-plugins-rs/-/merge_requests/2287>
This commit is contained in:
François Laignel 2025-06-11 17:30:16 +02:00
parent d0ae6b87b4
commit d156d8950f
6 changed files with 249 additions and 85 deletions

View file

@ -30,7 +30,7 @@ bitflags = "2.6.0"
libc = "0.2"
[target.'cfg(target_os = "windows")'.dependencies]
windows-sys = ">=0.52, <=0.59"
windows-sys = { version = ">=0.52, <=0.59", features = ["Win32_Foundation"] }
[target.'cfg(not(target_os = "android"))'.dependencies]
getifaddrs = "0.1"

View file

@ -26,7 +26,9 @@ use std::{
#[cfg(windows)]
use std::os::windows::io::{AsRawSocket, AsSocket, BorrowedSocket, OwnedSocket, RawSocket};
use socket2::{Domain, Protocol, SockAddr, Socket, Type};
use rustix::io as rio;
use rustix::net as rn;
use rustix::net::addr::SocketAddrArg;
use crate::runtime::RUNTIME_CAT;
@ -109,8 +111,8 @@ impl<T: AsFd + Send + 'static> Async<T> {
/// This method will put the handle in non-blocking mode and register it in
/// [epoll]/[kqueue]/[event ports]/[IOCP].
///
/// On Unix systems, the handle must implement `AsRawFd`, while on Windows it must implement
/// `AsRawSocket`.
/// On Unix systems, the handle must implement `AsFd`, while on Windows it must implement
/// `AsSocket`.
///
/// [epoll]: https://en.wikipedia.org/wiki/Epoll
/// [kqueue]: https://en.wikipedia.org/wiki/Kqueue
@ -118,28 +120,32 @@ impl<T: AsFd + Send + 'static> Async<T> {
/// [IOCP]: https://learn.microsoft.com/en-us/windows/win32/fileio/i-o-completion-ports
pub fn new(io: T) -> io::Result<Async<T>> {
// Put the file descriptor in non-blocking mode.
let fd = io.as_fd();
cfg_if::cfg_if! {
// ioctl(FIONBIO) sets the flag atomically, but we use this only on Linux
// for now, as with the standard library, because it seems to behave
// differently depending on the platform.
// https://github.com/rust-lang/rust/commit/efeb42be2837842d1beb47b51bb693c7474aba3d
// https://github.com/libuv/libuv/blob/e9d91fccfc3e5ff772d5da90e1c4a24061198ca0/src/unix/poll.c#L78-L80
// https://github.com/tokio-rs/mio/commit/0db49f6d5caf54b12176821363d154384357e70a
if #[cfg(target_os = "linux")] {
rustix::io::ioctl_fionbio(fd, true)?;
} else {
let previous = rustix::fs::fcntl_getfl(fd)?;
let new = previous | rustix::fs::OFlags::NONBLOCK;
if new != previous {
rustix::fs::fcntl_setfl(fd, new)?;
}
}
}
set_nonblocking(io.as_fd())?;
Self::new_nonblocking(io)
}
/// Creates an async I/O handle without setting it to non-blocking mode.
///
/// This method will register the handle in [epoll]/[kqueue]/[event ports]/[IOCP].
///
/// On Unix systems, the handle must implement `AsFd`, while on Windows it must implement
/// `AsSocket`.
///
/// [epoll]: https://en.wikipedia.org/wiki/Epoll
/// [kqueue]: https://en.wikipedia.org/wiki/Kqueue
/// [event ports]: https://illumos.org/man/port_create
/// [IOCP]: https://learn.microsoft.com/en-us/windows/win32/fileio/i-o-completion-ports
///
/// # Caveats
///
/// The caller should ensure that the handle is set to non-blocking mode or that it is okay if
/// it is not set. If not set to non-blocking mode, I/O operations may block the current thread
/// and cause a deadlock in an asynchronous context.
pub fn new_nonblocking(io: T) -> io::Result<Async<T>> {
// SAFETY: It is impossible to drop the I/O source while it is registered through
// this type.
let registration = unsafe { Registration::new(fd) };
let registration = unsafe { Registration::new(io.as_fd()) };
let source = Reactor::with_mut(|reactor| reactor.insert_io(registration))?;
Ok(Async {
@ -191,28 +197,43 @@ impl<T: AsSocket + Send + 'static> Async<T> {
/// This method will put the handle in non-blocking mode and register it in
/// [epoll]/[kqueue]/[event ports]/[IOCP].
///
/// On Unix systems, the handle must implement `AsRawFd`, while on Windows it must implement
/// `AsRawSocket`.
/// On Unix systems, the handle must implement `AsFd`, while on Windows it must implement
/// `AsSocket`.
///
/// [epoll]: https://en.wikipedia.org/wiki/Epoll
/// [kqueue]: https://en.wikipedia.org/wiki/Kqueue
/// [event ports]: https://illumos.org/man/port_create
/// [IOCP]: https://learn.microsoft.com/en-us/windows/win32/fileio/i-o-completion-ports
pub fn new(io: T) -> io::Result<Async<T>> {
let borrowed = io.as_socket();
// Put the socket in non-blocking mode.
//
// Safety: We assume `as_raw_socket()` returns a valid fd. When we can
// depend on Rust >= 1.63, where `AsFd` is stabilized, and when
// `TimerFd` implements it, we can remove this unsafe and simplify this.
rustix::io::ioctl_fionbio(borrowed, true)?;
set_nonblocking(io.as_socket())?;
Self::new_nonblocking(io)
}
/// Creates an async I/O handle without setting it to non-blocking mode.
///
/// This method will register the handle in [epoll]/[kqueue]/[event ports]/[IOCP].
///
/// On Unix systems, the handle must implement `AsFd`, while on Windows it must implement
/// `AsSocket`.
///
/// [epoll]: https://en.wikipedia.org/wiki/Epoll
/// [kqueue]: https://en.wikipedia.org/wiki/Kqueue
/// [event ports]: https://illumos.org/man/port_create
/// [IOCP]: https://learn.microsoft.com/en-us/windows/win32/fileio/i-o-completion-ports
///
/// # Caveats
///
/// The caller should ensure that the handle is set to non-blocking mode or that it is okay if
/// it is not set. If not set to non-blocking mode, I/O operations may block the current thread
/// and cause a deadlock in an asynchronous context.
pub fn new_nonblocking(io: T) -> io::Result<Async<T>> {
// Create the registration.
//
// SAFETY: It is impossible to drop the I/O source while it is registered through
// this type.
let registration = unsafe { Registration::new(borrowed) };
let registration = unsafe { Registration::new(io.as_socket()) };
let source = Reactor::with_mut(|reactor| reactor.insert_io(registration))?;
Ok(Async {
@ -264,7 +285,11 @@ impl<T: Send + 'static> Async<T> {
}
/// Gets a mutable reference to the inner I/O handle.
pub fn get_mut(&mut self) -> &mut T {
///
/// # Safety
///
/// The underlying I/O source must not be dropped using this function.
pub unsafe fn get_mut(&mut self) -> &mut T {
self.io.as_mut().unwrap()
}
@ -362,7 +387,11 @@ impl<T: Send + 'static> Async<T> {
/// sends a notification that the I/O handle is readable.
///
/// The closure receives a mutable reference to the I/O handle.
pub async fn read_with_mut<R>(
///
/// # Safety
///
/// In the closure, the underlying I/O source must not be dropped.
pub async unsafe fn read_with_mut<R>(
&mut self,
op: impl FnMut(&mut T) -> io::Result<R>,
) -> io::Result<R> {
@ -403,7 +432,12 @@ impl<T: Send + 'static> Async<T> {
/// sends a notification that the I/O handle is writable.
///
/// The closure receives a mutable reference to the I/O handle.
pub async fn write_with_mut<R>(
///
/// # Safety
///
/// The closure receives a mutable reference to the I/O handle. In the closure, the underlying
/// I/O source must not be dropped.
pub async unsafe fn write_with_mut<R>(
&mut self,
op: impl FnMut(&mut T) -> io::Result<R>,
) -> io::Result<R> {
@ -424,12 +458,6 @@ impl<T: Send + 'static> AsRef<T> for Async<T> {
}
}
impl<T: Send + 'static> AsMut<T> for Async<T> {
fn as_mut(&mut self) -> &mut T {
self.get_mut()
}
}
impl<T: Send + 'static> Drop for Async<T> {
fn drop(&mut self) {
if let Some(io) = self.io.take() {
@ -543,14 +571,14 @@ unsafe impl<T: IoSafe + ?Sized> IoSafe for &mut T {}
unsafe impl<T: IoSafe + ?Sized> IoSafe for Box<T> {}
unsafe impl<T: Clone + IoSafe> IoSafe for std::borrow::Cow<'_, T> {}
impl<T: Read + Send + 'static> AsyncRead for Async<T> {
impl<T: IoSafe + Read + Send + 'static> AsyncRead for Async<T> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
loop {
match (*self).get_mut().read(buf) {
match unsafe { (*self).get_mut() }.read(buf) {
Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
res => return Poll::Ready(res),
}
@ -564,7 +592,7 @@ impl<T: Read + Send + 'static> AsyncRead for Async<T> {
bufs: &mut [IoSliceMut<'_>],
) -> Poll<io::Result<usize>> {
loop {
match (*self).get_mut().read_vectored(bufs) {
match unsafe { (*self).get_mut() }.read_vectored(bufs) {
Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
res => return Poll::Ready(res),
}
@ -608,14 +636,14 @@ where
}
}
impl<T: Write + Send + 'static> AsyncWrite for Async<T> {
impl<T: IoSafe + Write + Send + 'static> AsyncWrite for Async<T> {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
loop {
match (*self).get_mut().write(buf) {
match unsafe { (*self).get_mut() }.write(buf) {
Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
res => return Poll::Ready(res),
}
@ -629,7 +657,7 @@ impl<T: Write + Send + 'static> AsyncWrite for Async<T> {
bufs: &[IoSlice<'_>],
) -> Poll<io::Result<usize>> {
loop {
match (*self).get_mut().write_vectored(bufs) {
match unsafe { (*self).get_mut() }.write_vectored(bufs) {
Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
res => return Poll::Ready(res),
}
@ -639,7 +667,7 @@ impl<T: Write + Send + 'static> AsyncWrite for Async<T> {
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
loop {
match (*self).get_mut().flush() {
match unsafe { (*self).get_mut() }.flush() {
Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
res => return Poll::Ready(res),
}
@ -739,11 +767,17 @@ impl TryFrom<std::net::TcpListener> for Async<std::net::TcpListener> {
impl Async<TcpStream> {
/// Creates a TCP connection to the specified address.
pub async fn connect<A: Into<SocketAddr>>(addr: A) -> io::Result<Async<TcpStream>> {
// Begin async connect.
// Figure out how to handle this address.
let addr = addr.into();
let domain = Domain::for_address(addr);
let socket = connect(addr.into(), domain, Some(Protocol::TCP))?;
let stream = Async::new(TcpStream::from(socket))?;
let (domain, sock_addr) = match addr {
SocketAddr::V4(v4) => (rn::AddressFamily::INET, v4.as_any()),
SocketAddr::V6(v6) => (rn::AddressFamily::INET6, v6.as_any()),
};
// Begin async connect.
let socket = connect(sock_addr, domain, Some(rn::ipproto::TCP))?;
// Use new_nonblocking because connect already sets socket to non-blocking mode.
let stream = Async::new_nonblocking(TcpStream::from(socket))?;
// The stream becomes writable when connected.
stream.writable().await?;
@ -900,9 +934,12 @@ impl TryFrom<std::os::unix::net::UnixListener> for Async<std::os::unix::net::Uni
impl Async<UnixStream> {
/// Creates a UDS stream connected to the specified path.
pub async fn connect<P: AsRef<Path>>(path: P) -> io::Result<Async<UnixStream>> {
let address = convert_path_to_socket_address(path.as_ref())?;
// Begin async connect.
let socket = connect(SockAddr::unix(path)?, Domain::UNIX, None)?;
let stream = Async::new(UnixStream::from(socket))?;
let socket = connect(address.into(), rn::AddressFamily::UNIX, None)?;
// Use new_nonblocking because connect already sets socket to non-blocking mode.
let stream = Async::new_nonblocking(UnixStream::from(socket))?;
// The stream becomes writable when connected.
stream.writable().await?;
@ -1008,8 +1045,16 @@ async fn optimistic(fut: impl Future<Output = io::Result<()>>) -> io::Result<()>
.await
}
fn connect(addr: SockAddr, domain: Domain, protocol: Option<Protocol>) -> io::Result<Socket> {
let sock_type = Type::STREAM;
fn connect(
addr: rn::SocketAddrAny,
domain: rn::AddressFamily,
protocol: Option<rn::Protocol>,
) -> io::Result<rustix::fd::OwnedFd> {
#[cfg(windows)]
use rustix::fd::AsFd;
setup_networking();
#[cfg(any(
target_os = "android",
target_os = "dragonfly",
@ -1020,10 +1065,13 @@ fn connect(addr: SockAddr, domain: Domain, protocol: Option<Protocol>) -> io::Re
target_os = "netbsd",
target_os = "openbsd"
))]
// If we can, set nonblocking at socket creation for unix
let sock_type = sock_type.nonblocking();
// This automatically handles cloexec on unix, no_inherit on windows and nosigpipe on macos
let socket = Socket::new(domain, sock_type, protocol)?;
let socket = rn::socket_with(
domain,
rn::SocketType::STREAM,
rn::SocketFlags::CLOEXEC | rn::SocketFlags::NONBLOCK,
protocol,
)?;
#[cfg(not(any(
target_os = "android",
target_os = "dragonfly",
@ -1034,14 +1082,132 @@ fn connect(addr: SockAddr, domain: Domain, protocol: Option<Protocol>) -> io::Re
target_os = "netbsd",
target_os = "openbsd"
)))]
// If the current platform doesn't support nonblocking at creation, enable it after creation
socket.set_nonblocking(true)?;
match socket.connect(&addr) {
let socket = {
#[cfg(not(any(
target_os = "aix",
target_vendor = "apple",
target_os = "espidf",
windows,
)))]
let flags = rn::SocketFlags::CLOEXEC;
#[cfg(any(
target_os = "aix",
target_vendor = "apple",
target_os = "espidf",
windows,
))]
let flags = rn::SocketFlags::empty();
// Create the socket.
let socket = rn::socket_with(domain, rn::SocketType::STREAM, flags, protocol)?;
// Set cloexec if necessary.
#[cfg(any(target_os = "aix", target_vendor = "apple"))]
rio::fcntl_setfd(&socket, rio::fcntl_getfd(&socket)? | rio::FdFlags::CLOEXEC)?;
// Set non-blocking mode.
set_nonblocking(socket.as_fd())?;
socket
};
// Set nosigpipe if necessary.
#[cfg(any(
target_vendor = "apple",
target_os = "freebsd",
target_os = "netbsd",
target_os = "dragonfly",
))]
rn::sockopt::set_socket_nosigpipe(&socket, true)?;
// Set the handle information to HANDLE_FLAG_INHERIT.
#[cfg(windows)]
unsafe {
if windows_sys::Win32::Foundation::SetHandleInformation(
socket.as_raw_socket() as _,
windows_sys::Win32::Foundation::HANDLE_FLAG_INHERIT,
windows_sys::Win32::Foundation::HANDLE_FLAG_INHERIT,
) == 0
{
return Err(io::Error::last_os_error());
}
}
#[allow(unreachable_patterns)]
match rn::connect(&socket, &addr) {
Ok(_) => {}
#[cfg(unix)]
Err(err) if err.raw_os_error() == Some(rustix::io::Errno::INPROGRESS.raw_os_error()) => {}
Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
Err(err) => return Err(err),
Err(rio::Errno::INPROGRESS) => {}
Err(rio::Errno::AGAIN) | Err(rio::Errno::WOULDBLOCK) => {}
Err(err) => return Err(err.into()),
}
Ok(socket)
}
#[inline]
fn setup_networking() {
#[cfg(windows)]
{
// On Windows, we need to call WSAStartup before calling any networking code.
// Make sure to call it at least once.
static INIT: std::sync::Once = std::sync::Once::new();
INIT.call_once(|| {
let _ = rustix::net::wsa_startup();
});
}
}
#[inline]
fn set_nonblocking(
#[cfg(unix)] fd: BorrowedFd<'_>,
#[cfg(windows)] fd: BorrowedSocket<'_>,
) -> io::Result<()> {
cfg_if::cfg_if! {
// ioctl(FIONBIO) sets the flag atomically, but we use this only on Linux
// for now, as with the standard library, because it seems to behave
// differently depending on the platform.
// https://github.com/rust-lang/rust/commit/efeb42be2837842d1beb47b51bb693c7474aba3d
// https://github.com/libuv/libuv/blob/e9d91fccfc3e5ff772d5da90e1c4a24061198ca0/src/unix/poll.c#L78-L80
// https://github.com/tokio-rs/mio/commit/0db49f6d5caf54b12176821363d154384357e70a
if #[cfg(any(windows, target_os = "linux"))] {
rustix::io::ioctl_fionbio(fd, true)?;
} else {
let previous = rustix::fs::fcntl_getfl(fd)?;
let new = previous | rustix::fs::OFlags::NONBLOCK;
if new != previous {
rustix::fs::fcntl_setfl(fd, new)?;
}
}
}
Ok(())
}
/// Converts a `Path` to its socket address representation.
///
/// This function is abstract socket-aware.
#[cfg(unix)]
#[inline]
fn convert_path_to_socket_address(path: &Path) -> io::Result<rn::SocketAddrUnix> {
// SocketAddrUnix::new() will throw EINVAL when a path with a zero in it is passed in.
// However, some users expect to be able to pass in paths to abstract sockets, which
// triggers this error as it has a zero in it. Therefore, if a path starts with a zero,
// make it an abstract socket.
#[cfg(any(target_os = "linux", target_os = "android"))]
let address = {
use std::os::unix::ffi::OsStrExt;
let path = path.as_os_str();
match path.as_bytes().first() {
Some(0) => rn::SocketAddrUnix::new_abstract_name(path.as_bytes().get(1..).unwrap())?,
_ => rn::SocketAddrUnix::new(path)?,
}
};
// Only Linux and Android support abstract sockets.
#[cfg(not(any(target_os = "linux", target_os = "android")))]
let address = rn::SocketAddrUnix::new(path)?;
Ok(address)
}

View file

@ -32,10 +32,7 @@ cfg_if::cfg_if! {
mod windows;
pub use windows::Registration;
} else if #[cfg(any(
target_os = "macos",
target_os = "ios",
target_os = "tvos",
target_os = "watchos",
target_vendor = "apple",
target_os = "freebsd",
target_os = "netbsd",
target_os = "openbsd",
@ -420,7 +417,7 @@ impl Reactor {
if let Some(source) = self.sources.get(ev.key) {
let mut state = source.state.lock().unwrap();
// Collect wakers if a writability event was emitted.
// Collect wakers if any event was emitted.
for &(dir, emitted) in &[(WRITE, ev.writable), (READ, ev.readable)] {
if emitted {
state[dir].tick = tick;

View file

@ -5,9 +5,11 @@ use polling::{Event, Poller};
use std::fmt;
use std::io::Result;
use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, RawFd};
use std::os::unix::io::{AsRawFd, BorrowedFd, RawFd};
/// The raw registration into the reactor.
///
/// This needs to be public, since it is technically exposed through the `QueueableSealed` trait.
#[doc(hidden)]
pub struct Registration {
/// Raw file descriptor for readability/writability.
@ -32,10 +34,8 @@ impl Registration {
/// # Safety
///
/// The provided file descriptor must be valid and not be closed while this object is alive.
pub(crate) unsafe fn new(f: impl AsFd) -> Self {
Self {
raw: f.as_fd().as_raw_fd(),
}
pub(crate) unsafe fn new(f: BorrowedFd<'_>) -> Self {
Self::Fd(f.as_raw_fd())
}
/// Registers the object into the reactor.

View file

@ -5,7 +5,7 @@ use polling::{Event, Poller};
use std::fmt;
use std::io::Result;
use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, RawFd};
use std::os::unix::io::{AsRawFd, BorrowedFd, RawFd};
/// The raw registration into the reactor.
#[doc(hidden)]
@ -31,10 +31,8 @@ impl Registration {
/// # Safety
///
/// The provided file descriptor must be valid and not be closed while this object is alive.
pub(crate) unsafe fn new(f: impl AsFd) -> Self {
Self {
raw: f.as_fd().as_raw_fd(),
}
pub(crate) unsafe fn new(f: BorrowedFd<'_>) -> Self {
Self { raw: f.as_raw_fd() }
}
/// Registers the object into the reactor.

View file

@ -4,7 +4,7 @@
use polling::{Event, Poller};
use std::fmt;
use std::io::Result;
use std::os::windows::io::{AsRawSocket, AsSocket, BorrowedSocket, RawSocket};
use std::os::windows::io::{AsRawSocket, BorrowedSocket, RawSocket};
/// The raw registration into the reactor.
#[doc(hidden)]
@ -18,6 +18,9 @@ pub struct Registration {
raw: RawSocket,
}
unsafe impl Send for Registration {}
unsafe impl Sync for Registration {}
impl fmt::Debug for Registration {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Debug::fmt(&self.raw, f)
@ -30,9 +33,9 @@ impl Registration {
/// # Safety
///
/// The provided file descriptor must be valid and not be closed while this object is alive.
pub(crate) unsafe fn new(f: impl AsSocket) -> Self {
pub(crate) unsafe fn new(f: BorrowedSocket<'_>) -> Self {
Self {
raw: f.as_socket().as_raw_socket(),
raw: f.as_raw_socket(),
}
}