payload middleware: switch to Rc, always inject if Payload isn't None

This commit is contained in:
asonix 2023-09-30 17:24:48 -05:00
parent 3267fb8301
commit e3462f6664

View file

@ -1,11 +1,10 @@
use std::{ use std::{
future::{ready, Ready}, future::{ready, Ready},
sync::Arc, rc::Rc,
}; };
use actix_web::{ use actix_web::{
dev::{Service, ServiceRequest, Transform}, dev::{Service, ServiceRequest, Transform},
http::Method,
HttpMessage, HttpMessage,
}; };
use streem::IntoStreamer; use streem::IntoStreamer;
@ -48,7 +47,7 @@ async fn drain(rx: flume::Receiver<actix_web::dev::Payload>) {
} }
#[derive(Clone)] #[derive(Clone)]
struct DrainHandle(Option<Arc<actix_web::rt::task::JoinHandle<()>>>); struct DrainHandle(Option<Rc<actix_web::rt::task::JoinHandle<()>>>);
pub(crate) struct Payload { pub(crate) struct Payload {
sender: flume::Sender<actix_web::dev::Payload>, sender: flume::Sender<actix_web::dev::Payload>,
@ -67,7 +66,7 @@ pub(crate) struct PayloadStream {
impl DrainHandle { impl DrainHandle {
fn new(handle: actix_web::rt::task::JoinHandle<()>) -> Self { fn new(handle: actix_web::rt::task::JoinHandle<()>) -> Self {
Self(Some(Arc::new(handle))) Self(Some(Rc::new(handle)))
} }
} }
@ -75,7 +74,7 @@ impl Payload {
pub(crate) fn new() -> Self { pub(crate) fn new() -> Self {
let (tx, rx) = crate::sync::channel(LIMIT); let (tx, rx) = crate::sync::channel(LIMIT);
let handle = DrainHandle::new(crate::sync::spawn(async move { drain(rx).await })); let handle = DrainHandle::new(crate::sync::spawn(drain(rx)));
Payload { sender: tx, handle } Payload { sender: tx, handle }
} }
@ -83,7 +82,7 @@ impl Payload {
impl Drop for DrainHandle { impl Drop for DrainHandle {
fn drop(&mut self) { fn drop(&mut self) {
if let Some(handle) = self.0.take().and_then(Arc::into_inner) { if let Some(handle) = self.0.take().and_then(Rc::into_inner) {
handle.abort(); handle.abort();
} }
} }
@ -166,8 +165,9 @@ where
} }
fn call(&self, mut req: ServiceRequest) -> Self::Future { fn call(&self, mut req: ServiceRequest) -> Self::Future {
if matches!(*req.method(), Method::POST | Method::PATCH | Method::PUT) { let payload = req.take_payload();
let payload = req.take_payload();
if !matches!(payload, actix_web::dev::Payload::None) {
let payload: LocalBoxStream<'static, _> = Box::pin(PayloadStream { let payload: LocalBoxStream<'static, _> = Box::pin(PayloadStream {
inner: Some(payload), inner: Some(payload),
sender: self.sender.clone(), sender: self.sender.clone(),