diff --git a/actix-http/src/h1/encoder.rs b/actix-http/src/h1/encoder.rs index 77e34bcdc..06fccdb69 100644 --- a/actix-http/src/h1/encoder.rs +++ b/actix-http/src/h1/encoder.rs @@ -6,6 +6,7 @@ use std::{ slice::from_raw_parts_mut, }; +use ahash::AHashMap; use bytes::{BufMut, BytesMut}; use crate::{ @@ -109,28 +110,21 @@ pub(crate) trait MessageType: Sized { BodySize::None => dst.put_slice(b"\r\n"), } - // Connection - match conn_type { - ConnectionType::Upgrade => dst.put_slice(b"connection: upgrade\r\n"), - ConnectionType::KeepAlive if version < Version::HTTP_11 => { - if camel_case { - dst.put_slice(b"Connection: keep-alive\r\n") - } else { - dst.put_slice(b"connection: keep-alive\r\n") - } - } - ConnectionType::Close if version >= Version::HTTP_11 => { - if camel_case { - dst.put_slice(b"Connection: close\r\n") - } else { - dst.put_slice(b"connection: close\r\n") - } - } - _ => {} - } + let headers = match self.extra_headers() { + Some(extra_headers) => self + .headers() + .inner + .iter() + .filter(|(name, _)| !extra_headers.contains_key(*name)) + .chain(extra_headers.inner.iter()) + .collect::>(), + None => self.headers().inner.iter().collect::>(), + }; + + // write connection header + self.write_connection_header(&headers, conn_type, version, dst); // write headers - let mut has_date = false; let mut buf = dst.chunk_mut().as_mut_ptr(); @@ -141,7 +135,7 @@ pub(crate) trait MessageType: Sized { // container's knowledge, this is used to sync the containers cursor after data is written let mut pos = 0; - self.write_headers(|key, value| { + self.write_headers(&headers, |key, value| { match *key { CONNECTION => return, TRANSFER_ENCODING | CONTENT_LENGTH if skip_len => return, @@ -221,22 +215,54 @@ pub(crate) trait MessageType: Sized { Ok(()) } - fn write_headers(&mut self, mut f: F) + fn write_connection_header( + &self, + headers: &AHashMap<&HeaderName, &Value>, + conn_type: ConnectionType, + version: Version, + buf: &mut B, + ) { + let camel_case = self.camel_case(); + + if let Some(header_value) = headers.get(&CONNECTION) { + if camel_case { + buf.put_slice(b"Connection: "); + } else { + buf.put_slice(b"connection: "); + } + for val in header_value.iter() { + buf.put_slice(val.as_ref()); + } + buf.put_slice(b"\r\n"); + return; + } + + // Connection + match conn_type { + ConnectionType::Upgrade => buf.put_slice(b"connection: upgrade\r\n"), + ConnectionType::KeepAlive if version < Version::HTTP_11 => { + if camel_case { + buf.put_slice(b"Connection: keep-alive\r\n") + } else { + buf.put_slice(b"connection: keep-alive\r\n") + } + } + ConnectionType::Close if version >= Version::HTTP_11 => { + if camel_case { + buf.put_slice(b"Connection: close\r\n") + } else { + buf.put_slice(b"connection: close\r\n") + } + } + _ => {} + } + } + + fn write_headers(&self, headers: &AHashMap<&HeaderName, &Value>, mut f: F) where F: FnMut(&HeaderName, &Value), { - match self.extra_headers() { - Some(headers) => { - // merging headers from head and extra headers. - self.headers() - .inner - .iter() - .filter(|(name, _)| !headers.contains_key(*name)) - .chain(headers.inner.iter()) - .for_each(|(k, v)| f(k, v)) - } - None => self.headers().inner.iter().for_each(|(k, v)| f(k, v)), - } + headers.iter().for_each(|(key, value)| f(key, value)); } } @@ -668,4 +694,61 @@ mod tests { assert!(!data.contains("content-length: 0\r\n")); assert!(!data.contains("transfer-encoding: chunked\r\n")); } + + #[actix_rt::test] + async fn test_close_connection_header_even_keep_alive_was_provided() { + let mut bytes = BytesMut::with_capacity(2048); + + let mut res = Response::with_body(StatusCode::OK, ()); + res.headers_mut() + .insert(CONNECTION, HeaderValue::from_static("close")); + + let _ = res.encode_headers( + &mut bytes, + Version::HTTP_11, + BodySize::Stream, + ConnectionType::KeepAlive, + &ServiceConfig::default(), + ); + let data = String::from_utf8(Vec::from(bytes.split().freeze().as_ref())).unwrap(); + assert!(data.contains("connection: close\r\n")); + } + + #[actix_rt::test] + async fn test_keep_alive_connection_header_when_provided() { + let mut bytes = BytesMut::with_capacity(2048); + + let mut res = Response::with_body(StatusCode::OK, ()); + res.headers_mut() + .insert(CONNECTION, HeaderValue::from_static("keep-alive")); + + let _ = res.encode_headers( + &mut bytes, + Version::HTTP_11, + BodySize::Stream, + ConnectionType::KeepAlive, + &ServiceConfig::default(), + ); + let data = String::from_utf8(Vec::from(bytes.split().freeze().as_ref())).unwrap(); + assert!(data.contains("connection: keep-alive\r\n")); + } + + #[actix_rt::test] + async fn test_keep_alive_connection_header_even_close_was_provided() { + let mut bytes = BytesMut::with_capacity(2048); + + let mut res = Response::with_body(StatusCode::OK, ()); + res.headers_mut() + .insert(CONNECTION, HeaderValue::from_static("keep-alive")); + + let _ = res.encode_headers( + &mut bytes, + Version::HTTP_11, + BodySize::Stream, + ConnectionType::Close, + &ServiceConfig::default(), + ); + let data = String::from_utf8(Vec::from(bytes.split().freeze().as_ref())).unwrap(); + assert!(data.contains("connection: keep-alive\r\n")); + } }