From 09cb2a53b029ef2c60e85d67b9facf9a5853770d Mon Sep 17 00:00:00 2001 From: "Aode (Lion)" Date: Wed, 13 Oct 2021 19:06:53 -0500 Subject: [PATCH] Rewrite to avoid direct AsyncX impls --- Cargo.lock | 26 +- Cargo.toml | 2 +- src/either.rs | 43 +++ src/file.rs | 668 +++++++++++++++++++----------------------- src/magick.rs | 6 +- src/main.rs | 21 +- src/range.rs | 33 +-- src/stream.rs | 39 +-- src/upload_manager.rs | 52 +--- 9 files changed, 399 insertions(+), 491 deletions(-) create mode 100644 src/either.rs diff --git a/Cargo.lock b/Cargo.lock index d37f00f..d0395c4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1282,9 +1282,9 @@ checksum = "bc881b2c22681370c6a780e47af9840ef841837bc98118431d4e1868bd0c1086" [[package]] name = "proc-macro2" -version = "1.0.29" +version = "1.0.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9f5105d4fdaab20335ca9565e106a5d9b82b6219b5ba735731124ac6711d23d" +checksum = "edc3358ebc67bc8b7fa0c007f945b0b18226f78437d61bec735a9eb96b61ee70" dependencies = [ "unicode-xid", ] @@ -1448,6 +1448,15 @@ dependencies = [ "winapi", ] +[[package]] +name = "rio" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e98c25665909853c07874301124482754434520ab572ac6a22e90366de6685b" +dependencies = [ + "libc", +] + [[package]] name = "rustc_version" version = "0.2.3" @@ -1617,9 +1626,9 @@ dependencies = [ [[package]] name = "sharded-slab" -version = "0.1.3" +version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "740223c51853f3145fe7c90360d2d4232f2b62e3449489c207eccde818979982" +checksum = "900fba806f70c630b0a382d0d825e17a0f19fcd059a2ade1ff237bcddf446b31" dependencies = [ "lazy_static", ] @@ -1635,9 +1644,9 @@ dependencies = [ [[package]] name = "slab" -version = "0.4.4" +version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c307a32c1c5c437f38c7fd45d753050587732ba8628319fbdf12a7e289ccc590" +checksum = "9def91fd1e018fe007022791f865d0ccc9b3a0d5001e01aabb8b40e46000afb5" [[package]] name = "sled" @@ -1653,6 +1662,7 @@ dependencies = [ "libc", "log", "parking_lot", + "rio", ] [[package]] @@ -1923,9 +1933,9 @@ dependencies = [ [[package]] name = "tokio-macros" -version = "1.4.1" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "154794c8f499c2619acd19e839294703e9e32e7630ef5f46ea80d4ef0fbee5eb" +checksum = "b2dd85aeaba7b68df939bd357c6afb36c87951be9e80bf9c859f2fc3e9fca0fd" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index dcfd0b5..4c629c7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,7 +11,7 @@ edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [features] default = [] -io-uring = ["actix-rt/io-uring", "actix-server/io-uring", "tokio-uring"] +io-uring = ["actix-rt/io-uring", "actix-server/io-uring", "tokio-uring", "sled/io_uring"] [dependencies] actix-form-data = "0.6.0-beta.1" diff --git a/src/either.rs b/src/either.rs new file mode 100644 index 0000000..5bdeedd --- /dev/null +++ b/src/either.rs @@ -0,0 +1,43 @@ +use futures_util::stream::Stream; +use std::{ + pin::Pin, + task::{Context, Poll}, +}; +use tokio::io::{AsyncRead, ReadBuf}; + +pub(crate) enum Either { + Left(Left), + Right(Right), +} + +impl AsyncRead for Either +where + Left: AsyncRead + Unpin, + Right: AsyncRead + Unpin, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + match *self { + Self::Left(ref mut left) => Pin::new(left).poll_read(cx, buf), + Self::Right(ref mut right) => Pin::new(right).poll_read(cx, buf), + } + } +} + +impl Stream for Either +where + Left: Stream::Item> + Unpin, + Right: Stream + Unpin, +{ + type Item = ::Item; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match *self { + Self::Left(ref mut left) => Pin::new(left).poll_next(cx), + Self::Right(ref mut right) => Pin::new(right).poll_next(cx), + } + } +} diff --git a/src/file.rs b/src/file.rs index 3ac1bf0..2d4c6b8 100644 --- a/src/file.rs +++ b/src/file.rs @@ -1,52 +1,159 @@ +use futures_util::stream::Stream; +use std::{ + pin::Pin, + task::{Context, Poll}, +}; + #[cfg(feature = "io-uring")] pub(crate) use io_uring::File; #[cfg(not(feature = "io-uring"))] -pub(crate) use tokio::fs::File; +pub(crate) use tokio_file::File; + +struct CrateError(S); + +impl Stream for CrateError +where + S: Stream> + Unpin, + crate::error::Error: From, +{ + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.0) + .poll_next(cx) + .map(|opt| opt.map(|res| res.map_err(Into::into))) + } +} + +#[cfg(not(feature = "io-uring"))] +mod tokio_file { + use crate::Either; + use actix_web::web::{Bytes, BytesMut}; + use futures_util::stream::Stream; + use std::{fs::Metadata, io::SeekFrom, path::Path}; + use tokio::io::{AsyncRead, AsyncReadExt, AsyncSeekExt, AsyncWrite, AsyncWriteExt}; + use tokio_util::codec::{BytesCodec, FramedRead}; + + pub(crate) struct File { + inner: tokio::fs::File, + } + + impl File { + pub(crate) async fn open(path: impl AsRef) -> std::io::Result { + Ok(File { + inner: tokio::fs::File::open(path).await?, + }) + } + + pub(crate) async fn create(path: impl AsRef) -> std::io::Result { + Ok(File { + inner: tokio::fs::File::create(path).await?, + }) + } + + pub(crate) async fn metadata(&self) -> std::io::Result { + self.inner.metadata().await + } + + pub(crate) async fn write_from_bytes<'a>( + &'a mut self, + mut bytes: Bytes, + ) -> std::io::Result<()> { + self.inner.write_all_buf(&mut bytes).await?; + Ok(()) + } + + pub(crate) async fn write_from_async_read<'a, R>( + &'a mut self, + mut reader: R, + ) -> std::io::Result<()> + where + R: AsyncRead + Unpin, + { + tokio::io::copy(&mut reader, &mut self.inner).await?; + Ok(()) + } + + pub(crate) async fn read_to_async_write<'a, W>( + &'a mut self, + writer: &'a mut W, + ) -> std::io::Result<()> + where + W: AsyncWrite + Unpin, + { + tokio::io::copy(&mut self.inner, writer).await?; + Ok(()) + } + + pub(crate) async fn read_to_stream( + mut self, + from_start: Option, + len: Option, + ) -> Result< + impl Stream> + Unpin, + crate::error::Error, + > { + let obj = match (from_start, len) { + (Some(lower), Some(upper)) => { + self.inner.seek(SeekFrom::Start(lower)).await?; + Either::Left(self.inner.take(upper)) + } + (None, Some(upper)) => Either::Left(self.inner.take(upper)), + (Some(lower), None) => { + self.inner.seek(SeekFrom::Start(lower)).await?; + Either::Right(self.inner) + } + (None, None) => Either::Right(self.inner), + }; + + Ok(super::CrateError(BytesFreezer(FramedRead::new( + obj, + BytesCodec::new(), + )))) + } + } + + struct BytesFreezer(S); + + impl Stream for BytesFreezer + where + S: Stream> + Unpin, + { + type Item = Result; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::pin::Pin::new(&mut self.0) + .poll_next(cx) + .map(|opt| opt.map(|res| res.map(BytesMut::freeze))) + } + } +} #[cfg(feature = "io-uring")] mod io_uring { + use actix_web::web::Bytes; + use futures_util::stream::Stream; use std::{ convert::TryInto, fs::Metadata, future::Future, - io::SeekFrom, path::{Path, PathBuf}, pin::Pin, - task::{Context, Poll, Waker}, + task::{Context, Poll}, + }; + use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; + use tokio_uring::{ + buf::{IoBuf, IoBufMut, Slice}, + BufResult, }; - use tokio::io::{AsyncRead, AsyncSeek, AsyncWrite, ReadBuf}; - - type IoFuture = - Pin, Vec)>>>; - - type FlushFuture = Pin)>>>; - - type ShutdownFuture = Pin>>>; - - type SeekFuture = Pin>>>; - - enum FileState { - Reading { future: IoFuture }, - Writing { future: IoFuture }, - Syncing { future: FlushFuture }, - Seeking { future: SeekFuture }, - Shutdown { future: ShutdownFuture }, - Pending, - } - - impl FileState { - fn take(&mut self) -> Self { - std::mem::replace(self, FileState::Pending) - } - } pub(crate) struct File { path: PathBuf, - inner: Option, - cursor: usize, - wakers: Vec, - state: FileState, + inner: tokio_uring::fs::File, } impl File { @@ -54,10 +161,7 @@ mod io_uring { tracing::info!("Opening io-uring file"); Ok(File { path: path.as_ref().to_owned(), - inner: Some(tokio_uring::fs::File::open(path).await?), - cursor: 0, - wakers: vec![], - state: FileState::Pending, + inner: tokio_uring::fs::File::open(path).await?, }) } @@ -65,10 +169,7 @@ mod io_uring { tracing::info!("Creating io-uring file"); Ok(File { path: path.as_ref().to_owned(), - inner: Some(tokio_uring::fs::File::create(path).await?), - cursor: 0, - wakers: vec![], - state: FileState::Pending, + inner: tokio_uring::fs::File::create(path).await?, }) } @@ -76,347 +177,194 @@ mod io_uring { tokio::fs::metadata(&self.path).await } - fn poll_read( - &mut self, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - mut future: IoFuture, - ) -> Poll> { - match Pin::new(&mut future).poll(cx) { - Poll::Ready((file, Ok(bytes_read), vec)) => { - self.cursor += bytes_read; - self.inner = Some(file); - buf.put_slice(&vec[0..bytes_read]); + pub(crate) async fn write_from_bytes<'a>( + &'a mut self, + bytes: Bytes, + ) -> std::io::Result<()> { + let mut buf = bytes.to_vec(); + let len: u64 = buf.len().try_into().unwrap(); - // Wake tasks waiting on read to complete - for waker in self.wakers.drain(..) { - waker.wake(); - } + let mut cursor: u64 = 0; - Poll::Ready(Ok(())) - } - Poll::Ready((file, Err(err), _vec)) => { - self.inner = Some(file); - // Wake tasks waiting on read to complete - for waker in self.wakers.drain(..) { - waker.wake(); - } - - Poll::Ready(Err(err)) - } - Poll::Pending => { - self.state = FileState::Reading { future }; - - Poll::Pending - } - } - } - - fn poll_write( - &mut self, - cx: &mut Context<'_>, - mut future: IoFuture, - ) -> Poll> { - match Pin::new(&mut future).poll(cx) { - Poll::Ready((file, Ok(bytes_written), _vec)) => { - self.cursor += bytes_written; - self.inner = Some(file); - - for waker in self.wakers.drain(..) { - waker.wake(); - } - - Poll::Ready(Ok(bytes_written)) - } - Poll::Ready((file, Err(err), _vec)) => { - self.inner = Some(file); - - for waker in self.wakers.drain(..) { - waker.wake(); - } - - Poll::Ready(Err(err)) - } - Poll::Pending => { - self.state = FileState::Writing { future }; - - Poll::Pending - } - } - } - - fn poll_flush( - &mut self, - cx: &mut Context<'_>, - mut future: FlushFuture, - ) -> Poll> { - match Pin::new(&mut future).poll(cx) { - Poll::Ready((file, res)) => { - self.inner = Some(file); - - for waker in self.wakers.drain(..) { - waker.wake(); - } - - Poll::Ready(res) - } - Poll::Pending => { - self.state = FileState::Syncing { future }; - - Poll::Pending - } - } - } - - fn poll_shutdown( - &mut self, - cx: &mut Context<'_>, - mut future: ShutdownFuture, - ) -> Poll> { - match Pin::new(&mut future).poll(cx) { - Poll::Ready(res) => { - for waker in self.wakers.drain(..) { - waker.wake(); - } - - Poll::Ready(res) - } - Poll::Pending => { - self.state = FileState::Shutdown { future }; - - Poll::Pending - } - } - } - - fn poll_seek( - &mut self, - cx: &mut Context<'_>, - mut future: SeekFuture, - ) -> Poll> { - match Pin::new(&mut future).poll(cx) { - Poll::Ready(Ok(new_position)) => { - for waker in self.wakers.drain(..) { - waker.wake(); - } - - if let Ok(position) = new_position.try_into() { - self.cursor = position; - Poll::Ready(Ok(new_position)) - } else { - Poll::Ready(Err(std::io::ErrorKind::Other.into())) - } - } - Poll::Ready(Err(err)) => { - for waker in self.wakers.drain(..) { - waker.wake(); - } - - Poll::Ready(Err(err)) - } - Poll::Pending => { - self.state = FileState::Seeking { future }; - - Poll::Pending - } - } - } - - fn prepare_read(&mut self, buf: &mut ReadBuf<'_>) -> IoFuture { - let bytes_to_read = buf.remaining().min(65_536); - - let vec = vec![0u8; bytes_to_read]; - - let file = self.inner.take().unwrap(); - let position: u64 = self.cursor.try_into().unwrap(); - - Box::pin(async move { - let (res, vec) = file.read_at(vec, position).await; - (file, res, vec) - }) - } - - fn prepare_write(&mut self, buf: &[u8]) -> IoFuture { - let vec = buf.to_vec(); - - let file = self.inner.take().unwrap(); - let position: u64 = self.cursor.try_into().unwrap(); - - Box::pin(async move { - let (res, vec) = file.write_at(vec, position).await; - (file, res, vec) - }) - } - - fn prepare_flush(&mut self) -> FlushFuture { - let file = self.inner.take().unwrap(); - - Box::pin(async move { - let res = file.sync_all().await; - (file, res) - }) - } - - fn prepare_shutdown(&mut self) -> ShutdownFuture { - let file = self.inner.take().unwrap(); - - Box::pin(async move { - file.sync_all().await?; - file.close().await - }) - } - - fn prepare_seek(&self, from_end: i64) -> SeekFuture { - let path = self.path.clone(); - - Box::pin(async move { - let meta = tokio::fs::metadata(path).await?; - let end = meta.len(); - - if from_end < 0 { - let from_end = (-1) * from_end; - let from_end: u64 = - from_end.try_into().map_err(|_| std::io::ErrorKind::Other)?; - - return Ok(end + from_end); + loop { + if cursor == len { + break; } - let from_end: u64 = from_end.try_into().map_err(|_| std::io::ErrorKind::Other)?; + let cursor_usize: usize = cursor.try_into().unwrap(); + let (res, slice) = self.inner.write_at(buf.slice(cursor_usize..), cursor).await; + let n: usize = res?; - if from_end > end { - return Err(std::io::ErrorKind::Other.into()); + if n == 0 { + return Err(std::io::ErrorKind::UnexpectedEof.into()); } - Ok(end - from_end) - }) - } - - fn register_waker(&mut self, cx: &mut Context<'_>) -> Poll { - let already_registered = self.wakers.iter().any(|waker| cx.waker().will_wake(waker)); - - if !already_registered { - self.wakers.push(cx.waker().clone()); + buf = slice.into_inner(); + let n: u64 = n.try_into().unwrap(); + cursor += n; } - Poll::Pending + Ok(()) + } + + pub(crate) async fn write_from_async_read<'a, R>( + &'a mut self, + mut reader: R, + ) -> std::io::Result<()> + where + R: AsyncRead + Unpin, + { + let metadata = self.metadata().await?; + let size = metadata.len(); + + let mut cursor: u64 = 0; + + loop { + let max_size = (size - cursor).min(65_536); + let mut buf = Vec::with_capacity(max_size.try_into().unwrap()); + + let n = (&mut reader).take(max_size).read_to_end(&mut buf).await?; + + if n == 0 { + break; + } + + let mut buf: Slice> = buf.slice(..n); + let mut position = 0; + + loop { + if position == buf.len() { + break; + } + + let (res, slice) = self.write_at(buf.slice(position..), cursor).await; + position += res?; + + buf = slice.into_inner(); + } + + let position: u64 = position.try_into().unwrap(); + cursor += position; + } + + Ok(()) + } + + pub(crate) async fn read_to_async_write<'a, W>( + &'a mut self, + writer: &mut W, + ) -> std::io::Result<()> + where + W: AsyncWrite + Unpin, + { + let metadata = self.metadata().await?; + let size = metadata.len(); + + let mut cursor: u64 = 0; + + loop { + if cursor == size { + break; + } + + let max_size = (size - cursor).min(65_536); + let buf = Vec::with_capacity(max_size.try_into().unwrap()); + + let (res, mut buf): (_, Vec) = self.read_at(buf, cursor).await; + let n: usize = res?; + + if n == 0 { + return Err(std::io::ErrorKind::UnexpectedEof.into()); + } + + writer.write_all(&mut buf[0..n]).await?; + + let n: u64 = n.try_into().unwrap(); + cursor += n; + } + + Ok(()) + } + + pub(crate) async fn read_to_stream( + self, + from_start: Option, + len: Option, + ) -> Result< + impl Stream> + Unpin, + crate::error::Error, + > { + let size = self.metadata().await?.len(); + + let cursor = from_start.unwrap_or(0); + let size = len.unwrap_or(size - cursor) + cursor; + + Ok(super::CrateError(BytesStream { + file: Some(self), + size, + cursor, + fut: None, + })) + } + + async fn read_at(&self, buf: T, pos: u64) -> BufResult { + self.inner.read_at(buf, pos).await + } + + async fn write_at(&self, buf: T, pos: u64) -> BufResult { + self.inner.write_at(buf, pos).await } } - impl AsyncRead for File { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - match self.state.take() { - FileState::Pending => { - let future = (*self).prepare_read(buf); - - (*self).poll_read(cx, buf, future) - } - FileState::Reading { future } => (*self).poll_read(cx, buf, future), - _ => (*self).register_waker(cx), - } - } + struct BytesStream { + file: Option, + size: u64, + cursor: u64, + fut: Option>)>>>>, } - impl AsyncWrite for File { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - match self.state.take() { - FileState::Pending => { - let future = (*self).prepare_write(buf); + impl Stream for BytesStream { + type Item = std::io::Result; - (*self).poll_write(cx, future) + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut fut = if let Some(fut) = self.fut.take() { + fut + } else { + let file = self.file.take().unwrap(); + + if self.cursor == self.size { + return Poll::Ready(None); } - FileState::Writing { future } => (*self).poll_write(cx, future), - _ => (*self).register_waker(cx), - } - } - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.state.take() { - FileState::Pending => { - let future = (*self).prepare_flush(); + let cursor = self.cursor; + let max_size = self.size - self.cursor; - (*self).poll_flush(cx, future) + Box::pin(async move { + let buf = Vec::with_capacity(max_size.try_into().unwrap()); + + let buf_res = file.read_at(buf, cursor).await; + + (file, buf_res) + }) + }; + + match Pin::new(&mut fut).poll(cx) { + Poll::Pending => { + self.fut = Some(fut); + Poll::Pending } - FileState::Syncing { future } => (*self).poll_flush(cx, future), - _ => (*self).register_waker(cx), - } - } + Poll::Ready((file, (Ok(n), mut buf))) => { + self.file = Some(file); + let _ = buf.split_off(n); + let n: u64 = match n.try_into() { + Ok(n) => n, + Err(_) => return Poll::Ready(Some(Err(std::io::ErrorKind::Other.into()))), + }; + self.cursor += n; - fn poll_shutdown( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - match self.state.take() { - FileState::Pending => { - let future = (*self).prepare_shutdown(); - - (*self).poll_shutdown(cx, future) + Poll::Ready(Some(Ok(Bytes::from(buf)))) } - FileState::Shutdown { future } => (*self).poll_shutdown(cx, future), - _ => (*self).register_waker(cx), - } - } - } - - impl AsyncSeek for File { - fn start_seek(mut self: Pin<&mut Self>, position: SeekFrom) -> std::io::Result<()> { - match position { - SeekFrom::Start(from_start) => { - self.cursor = from_start.try_into().unwrap(); - Ok(()) - } - SeekFrom::End(from_end) => match self.state.take() { - FileState::Pending => { - let future = self.prepare_seek(from_end); - - self.state = FileState::Seeking { future }; - Ok(()) - } - _ => Err(std::io::ErrorKind::Other.into()), - }, - SeekFrom::Current(from_current) => { - if from_current < 0 { - let to_subtract = (-1) * from_current; - let to_subtract: usize = to_subtract - .try_into() - .map_err(|_| std::io::ErrorKind::Other)?; - - if to_subtract > self.cursor { - return Err(std::io::ErrorKind::Other.into()); - } - - self.cursor -= to_subtract; - } else { - let from_current: usize = from_current - .try_into() - .map_err(|_| std::io::ErrorKind::Other)?; - - self.cursor += from_current; - } - - Ok(()) - } - } - } - - fn poll_complete( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - match self.state.take() { - FileState::Pending => Poll::Ready(Ok(self - .cursor - .try_into() - .map_err(|_| std::io::ErrorKind::Other)?)), - FileState::Seeking { future } => (*self).poll_seek(cx, future), - _ => Poll::Ready(Err(std::io::ErrorKind::Other.into())), + Poll::Ready((_, (Err(e), _))) => Poll::Ready(Some(Err(e))), } } } diff --git a/src/magick.rs b/src/magick.rs index 65ab524..40fc55c 100644 --- a/src/magick.rs +++ b/src/magick.rs @@ -138,8 +138,8 @@ pub(crate) async fn input_type_bytes(input: Bytes) -> Result, format: Format, ) -> std::io::Result { @@ -154,7 +154,7 @@ pub(crate) fn process_image_write_read( .arg(last_arg), )?; - Ok(process.write_read(input).unwrap()) + Ok(process.file_read(input).unwrap()) } impl Details { diff --git a/src/main.rs b/src/main.rs index 5ef6df4..5695eaa 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,10 +6,7 @@ use actix_web::{ }; use awc::Client; use dashmap::{mapref::entry::Entry, DashMap}; -use futures_util::{ - stream::{once, LocalBoxStream}, - Stream, -}; +use futures_util::{stream::once, Stream}; use once_cell::sync::Lazy; use opentelemetry::{ sdk::{propagation::TraceContextPropagator, Resource}, @@ -26,7 +23,7 @@ use std::{ }; use structopt::StructOpt; use tokio::{ - io::{AsyncReadExt, AsyncWriteExt}, + io::AsyncReadExt, sync::{ oneshot::{Receiver, Sender}, Semaphore, @@ -41,6 +38,7 @@ use tracing_log::LogTracer; use tracing_subscriber::{fmt::format::FmtSpan, layer::SubscriberExt, EnvFilter, Registry}; mod config; +mod either; mod error; mod exiftool; mod ffmpeg; @@ -56,6 +54,7 @@ mod validate; use self::{ config::{Config, Format}, + either::Either, error::{Error, UploadError}, middleware::{Deadline, Internal}, upload_manager::{Details, UploadManager, UploadManagerSession}, @@ -213,7 +212,7 @@ where // Try writing to a file #[instrument(name = "Saving file", skip(bytes))] -async fn safe_save_file(path: PathBuf, mut bytes: web::Bytes) -> Result<(), Error> { +async fn safe_save_file(path: PathBuf, bytes: web::Bytes) -> Result<(), Error> { if let Some(path) = path.parent() { // create the directory for the file debug!("Creating directory {:?}", path); @@ -236,7 +235,7 @@ async fn safe_save_file(path: PathBuf, mut bytes: web::Bytes) -> Result<(), Erro // try writing debug!("Writing to {:?}", path); - if let Err(e) = file.write_all_buf(&mut bytes).await { + if let Err(e) = file.write_from_bytes(bytes).await { error!("Error writing {:?}, {}", path, e); // remove file if writing failed before completion tokio::fs::remove_file(path).await?; @@ -545,7 +544,7 @@ async fn process( let file = crate::file::File::open(original_path.clone()).await?; let mut processed_reader = - crate::magick::process_image_write_read(file, thumbnail_args, format)?; + crate::magick::process_image_file_read(file, thumbnail_args, format)?; let mut vec = Vec::new(); processed_reader.read_to_end(&mut vec).await?; @@ -718,7 +717,7 @@ async fn ranged_file_resp( let mut builder = HttpResponse::PartialContent(); builder.insert_header(range.to_content_range(meta.len())); - (builder, range.chop_file(file).await?) + (builder, Either::Left(range.chop_file(file).await?)) } else { return Err(UploadError::Range.into()); } @@ -726,8 +725,8 @@ async fn ranged_file_resp( //No Range header in the request - return the entire document None => { let file = crate::file::File::open(path).await?; - let stream = Box::pin(crate::stream::bytes_stream(file)) as LocalBoxStream<'_, _>; - (HttpResponse::Ok(), stream) + let stream = file.read_to_stream(None, None).await?; + (HttpResponse::Ok(), Either::Right(stream)) } }; diff --git a/src/range.rs b/src/range.rs index c3c2ffa..5353d09 100644 --- a/src/range.rs +++ b/src/range.rs @@ -1,7 +1,4 @@ -use crate::{ - error::{Error, UploadError}, - stream::bytes_stream, -}; +use crate::error::{Error, UploadError}; use actix_web::{ dev::Payload, http::{ @@ -11,9 +8,8 @@ use actix_web::{ web::Bytes, FromRequest, HttpRequest, }; -use futures_util::stream::{once, LocalBoxStream, Stream}; -use std::{future::ready, io}; -use tokio::io::{AsyncReadExt, AsyncSeekExt}; +use futures_util::stream::{once, Stream}; +use std::future::ready; #[derive(Debug)] pub(crate) enum Range { @@ -61,25 +57,14 @@ impl Range { pub(crate) async fn chop_file( &self, - mut file: crate::file::File, - ) -> Result>, Error> { + file: crate::file::File, + ) -> Result> + Unpin, Error> { match self { - Range::Start(start) => { - file.seek(io::SeekFrom::Start(*start)).await?; - - Ok(Box::pin(bytes_stream(file))) - } - Range::SuffixLength(from_start) => { - file.seek(io::SeekFrom::Start(0)).await?; - let reader = file.take(*from_start); - - Ok(Box::pin(bytes_stream(reader))) - } + Range::Start(start) => file.read_to_stream(Some(*start), None).await, + Range::SuffixLength(from_start) => file.read_to_stream(None, Some(*from_start)).await, Range::Segment(start, end) => { - file.seek(io::SeekFrom::Start(*start)).await?; - let reader = file.take(end.saturating_sub(*start)); - - Ok(Box::pin(bytes_stream(reader))) + file.read_to_stream(Some(*start), Some(end.saturating_sub(*start))) + .await } } } diff --git a/src/stream.rs b/src/stream.rs index 434039f..33135dc 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -1,7 +1,5 @@ -use crate::error::Error; use actix_rt::task::JoinHandle; -use actix_web::web::{Bytes, BytesMut}; -use futures_util::Stream; +use actix_web::web::Bytes; use std::{ future::Future, pin::Pin, @@ -13,7 +11,6 @@ use tokio::{ process::{Child, Command}, sync::oneshot::{channel, Receiver}, }; -use tokio_util::codec::{BytesCodec, FramedRead}; use tracing::Instrument; use tracing::Span; @@ -33,8 +30,6 @@ pub(crate) struct ProcessRead { handle: JoinHandle<()>, } -struct BytesFreezer(S, Span); - impl Process { pub(crate) fn run(command: &str, args: &[&str]) -> std::io::Result { Self::spawn(Command::new(command).args(args)) @@ -100,9 +95,9 @@ impl Process { })) } - pub(crate) fn write_read( + pub(crate) fn file_read( mut self, - mut input_reader: impl AsyncRead + Unpin + 'static, + mut input_file: crate::file::File, ) -> Option { let mut stdin = self.child.stdin.take()?; let stdout = self.child.stdout.take()?; @@ -113,7 +108,7 @@ impl Process { let mut child = self.child; let handle = actix_rt::spawn( async move { - if let Err(e) = tokio::io::copy(&mut input_reader, &mut stdin).await { + if let Err(e) = input_file.read_to_async_write(&mut stdin).await { let _ = tx.send(e); return; } @@ -144,15 +139,6 @@ impl Process { } } -pub(crate) fn bytes_stream( - input: impl AsyncRead + Unpin, -) -> impl Stream> + Unpin { - BytesFreezer( - FramedRead::new(input, BytesCodec::new()), - tracing::info_span!("Serving bytes from AsyncRead"), - ) -} - impl AsyncRead for ProcessRead where I: AsyncRead + Unpin, @@ -198,23 +184,6 @@ impl Drop for ProcessRead { } } -impl Stream for BytesFreezer -where - S: Stream> + Unpin, -{ - type Item = Result; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let span = self.1.clone(); - span.in_scope(|| { - Pin::new(&mut self.0) - .poll_next(cx) - .map(|opt| opt.map(|res| res.map(|bytes_mut| bytes_mut.freeze()))) - .map_err(Error::from) - }) - } -} - impl std::fmt::Display for StatusError { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { write!(f, "Command failed with bad status") diff --git a/src/upload_manager.rs b/src/upload_manager.rs index 2987e97..1a98333 100644 --- a/src/upload_manager.rs +++ b/src/upload_manager.rs @@ -13,7 +13,7 @@ use std::{ sync::Arc, task::{Context, Poll}, }; -use tokio::io::{AsyncRead, AsyncWriteExt, ReadBuf}; +use tokio::io::{AsyncRead, ReadBuf}; use tracing::{debug, error, info, instrument, warn, Span}; use tracing_futures::Instrument; @@ -993,53 +993,7 @@ pub(crate) async fn safe_save_reader( let mut file = crate::file::File::create(to).await?; - tokio::io::copy(input, &mut file).await?; - - Ok(()) -} - -#[instrument(skip(stream))] -pub(crate) async fn safe_save_stream( - to: PathBuf, - mut stream: UploadStream, -) -> Result<(), Error> -where - Error: From, - E: Unpin, -{ - if let Some(path) = to.parent() { - debug!("Creating directory {:?}", path); - tokio::fs::create_dir_all(path).await?; - } - - debug!("Checking if {:?} already exists", to); - if let Err(e) = tokio::fs::metadata(&to).await { - if e.kind() != std::io::ErrorKind::NotFound { - return Err(e.into()); - } - } else { - return Err(UploadError::FileExists.into()); - } - - debug!("Writing stream to {:?}", to); - - let to1 = to.clone(); - let fut = async move { - let mut file = crate::file::File::create(to1).await?; - - while let Some(res) = stream.next().await { - let mut bytes = res?; - file.write_all_buf(&mut bytes).await?; - } - - Ok(()) - }; - - if let Err(e) = fut.await { - error!("Failed to save file: {}", e); - let _ = tokio::fs::remove_file(to).await; - return Err(e); - } + file.write_from_async_read(input).await?; Ok(()) } @@ -1107,7 +1061,7 @@ mod test { #[test] fn hasher_works() { let hash = test_on_arbiter!(async move { - let file1 = crate::file::File::open("./client-examples/earth.gif").await?; + let file1 = tokio::fs::File::open("./client-examples/earth.gif").await?; let mut hasher = Hasher::new(file1, Sha256::new());