diff --git a/actix-web/src/middleware/compress.rs b/actix-web/src/middleware/compress.rs index dce9fc378..de6219e2b 100644 --- a/actix-web/src/middleware/compress.rs +++ b/actix-web/src/middleware/compress.rs @@ -1,6 +1,7 @@ //! For middleware documentation, see [`Compress`]. use std::{ + fmt, future::Future, marker::PhantomData, pin::Pin, @@ -11,14 +12,14 @@ use actix_http::encoding::Encoder; use actix_service::{Service, Transform}; use actix_utils::future::{ok, Either, Ready}; use futures_core::ready; +use mime::Mime; use once_cell::sync::Lazy; use pin_project_lite::pin_project; -use mime::Mime; use crate::{ body::{EitherBody, MessageBody}, http::{ - header::{self, AcceptEncoding, ContentType, Encoding, HeaderValue}, + header::{self, AcceptEncoding, ContentEncoding, Encoding, HeaderValue}, StatusCode, }, service::{ServiceRequest, ServiceResponse}, @@ -72,17 +73,22 @@ use crate::{ /// ``` /// /// [feature flags]: ../index.html#crate-features -#[derive(Debug, Clone)] +#[derive(Clone)] #[non_exhaustive] pub struct Compress { - pub compress: fn(Mime) -> bool, + pub compress: fn(&HeaderValue) -> bool, } +impl fmt::Debug for Compress { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Compress").finish() + } +} impl Default for Compress { fn default() -> Self { - Compress { - compress: |_| { true } - } + Compress { + compress: |_| false, + } } } @@ -98,13 +104,16 @@ where type Future = Ready>; fn new_transform(&self, service: S) -> Self::Future { - ok(CompressMiddleware { service, compress: self.compress }) + ok(CompressMiddleware { + service, + compress: self.compress, + }) } } pub struct CompressMiddleware { service: S, - compress: fn(Mime) -> bool, + compress: fn(&HeaderValue) -> bool, } impl Service for CompressMiddleware @@ -131,6 +140,7 @@ where encoding: Encoding::identity(), fut: self.service.call(req), _phantom: PhantomData, + compress: self.compress, }) } @@ -158,6 +168,7 @@ where fut: self.service.call(req), encoding, _phantom: PhantomData, + compress: self.compress, }), } } @@ -172,6 +183,7 @@ pin_project! { fut: S::Future, encoding: Encoding, _phantom: PhantomData, + compress: fn(&HeaderValue) -> bool, } } @@ -182,8 +194,8 @@ where { type Output = Result>>, Error>; - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.project(); + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.as_mut().project(); match ready!(this.fut.poll(cx)) { Ok(resp) => { @@ -195,7 +207,19 @@ where }; Poll::Ready(Ok(resp.map_body(move |head, body| { - EitherBody::left(Encoder::response(enc, head, body)) + let content_type = head.headers.get(header::CONTENT_TYPE); + let should_compress = content_type + .map(|value| (self.compress)(value)) + .unwrap_or(true); + if should_compress { + EitherBody::left(Encoder::response(enc, head, body)) + } else { + EitherBody::left(Encoder::response( + ContentEncoding::Identity, + head, + body, + )) + } }))) } @@ -259,6 +283,7 @@ mod tests { use std::collections::HashSet; use super::*; + use crate::http::header::ContentType; use crate::{middleware::DefaultHeaders, test, web, App}; pub fn gzip_decode(bytes: impl AsRef<[u8]>) -> Vec { @@ -345,9 +370,9 @@ mod tests { App::new().wrap(Compress::default()).route( "/image", web::get().to(move || { - let mut builder = HttpResponse::Ok(); - builder.body(DATA); - builder.insert_header(ContentType::jpeg()); + let builder = HttpResponse::Ok() + .insert_header(ContentType::jpeg()) + .body(DATA); builder }), ) @@ -357,9 +382,12 @@ mod tests { .uri("/image") .insert_header((header::ACCEPT_ENCODING, "gzip")) .to_request(); - let res = test::call_service(&app, req).await; + let res = test::call_service(&app, req).await; assert_eq!(res.status(), StatusCode::OK); - assert_eq!(res.headers().get(header::CONTENT_TYPE).unwrap(), "gzip"); + assert_eq!( + res.headers().get(header::CONTENT_TYPE).unwrap(), + "image/jpeg" + ); let bytes = test::read_body(res).await; assert_eq!(bytes, DATA.as_bytes()); }