From 35cadbbe0b2be6cd0d5a665d2b2fa6e2973d91a8 Mon Sep 17 00:00:00 2001
From: asonix
Date: Sat, 18 May 2024 12:30:19 -0500
Subject: [PATCH] actix-http: play with allowing bytes passthrough in the h1
 encoder

---
 actix-http/src/h1/big_bytes.rs  | 111 +++++++++++
 actix-http/src/h1/codec.rs      |  34 +++-
 actix-http/src/h1/dispatcher.rs |  38 ++--
 actix-http/src/h1/encoder.rs    | 323 +++++++++++++++++++++++++++++++-
 actix-http/src/h1/mod.rs        |   1 +
 5 files changed, 478 insertions(+), 29 deletions(-)
 create mode 100644 actix-http/src/h1/big_bytes.rs

diff --git a/actix-http/src/h1/big_bytes.rs b/actix-http/src/h1/big_bytes.rs
new file mode 100644
index 000000000..ef31a61de
--- /dev/null
+++ b/actix-http/src/h1/big_bytes.rs
@@ -0,0 +1,111 @@
+use std::collections::VecDeque;
+
+use bytes::{Buf, BufMut, Bytes, BytesMut};
+
+const SIXTYFOUR_KB: usize = 1024 * 64;
+
+pub(super) struct BigBytes {
+    buffer: BytesMut,
+    frozen: VecDeque<Bytes>,
+    frozen_len: usize,
+}
+
+impl BigBytes {
+    pub(super) fn with_capacity(capacity: usize) -> Self {
+        Self {
+            buffer: BytesMut::with_capacity(capacity),
+            frozen: VecDeque::default(),
+            frozen_len: 0,
+        }
+    }
+
+    // Clear the internal queue and buffer, resetting length to zero
+    pub(super) fn clear(&mut self) {
+        std::mem::take(&mut self.frozen);
+        self.frozen_len = 0;
+        self.buffer.clear();
+    }
+
+    // Return a mutable reference to the underlying buffer. This should only be used when dealing
+    // with small allocations (e.g. writing headers)
+    pub(super) fn buffer_mut(&mut self) -> &mut BytesMut {
+        &mut self.buffer
+    }
+
+    // Reserve the requested size, if it is less than 64KB
+    pub(super) fn reserve(&mut self, count: usize) {
+        if count < SIXTYFOUR_KB {
+            self.buffer.reserve(count);
+        }
+    }
+
+    pub(super) fn total_len(&mut self) -> usize {
+        self.frozen_len + self.buffer.len()
+    }
+
+    pub(super) fn is_empty(&self) -> bool {
+        self.frozen_len == 0 && self.buffer.is_empty()
+    }
+
+    // Add the `bytes` to the internal structure. If `bytes` is 64KB or larger, it is pushed into a
+    // queue; otherwise, it is added to a buffer.
+    pub(super) fn put_bytes(&mut self, bytes: Bytes) {
+        if bytes.len() < SIXTYFOUR_KB {
+            self.buffer.extend_from_slice(&bytes);
+        } else {
+            if !self.buffer.is_empty() {
+                let current = self.buffer.split().freeze();
+                self.frozen_len += current.len();
+                self.frozen.push_back(current);
+            }
+
+            self.frozen_len += bytes.len();
+            self.frozen.push_back(bytes);
+        }
+    }
+
+    // Put a slice into the internal structure. This is always added to the internal buffer
+    pub(super) fn extend_from_slice(&mut self, slice: &[u8]) {
+        self.buffer.extend_from_slice(slice);
+    }
+
+    // Returns a slice of the frontmost buffer
+    pub(super) fn front_slice(&self) -> &[u8] {
+        if let Some(front) = self.frozen.front() {
+            &front
+        } else {
+            &self.buffer
+        }
+    }
+
+    // Advances the first buffer by `count` bytes. If the first buffer is advanced to completion,
+    // it is popped from the queue
+    pub(super) fn advance(&mut self, count: usize) {
+        if let Some(front) = self.frozen.front_mut() {
+            front.advance(count);
+
+            if front.is_empty() {
+                self.frozen.pop_front();
+            }
+
+            self.frozen_len -= count;
+        } else {
+            self.buffer.advance(count);
+        }
+    }
+
+    // Drain the BigBytes, writing everything into the provided BytesMut
+    pub(super) fn write_to(&mut self, dst: &mut BytesMut) {
+        dst.reserve(self.total_len());
+
+        for buf in &self.frozen {
+            dst.put_slice(buf);
+        }
+
+        dst.put_slice(&self.buffer.split());
+
+        self.frozen_len = 0;
+
+        std::mem::take(&mut self.frozen);
+    }
+}
diff --git a/actix-http/src/h1/codec.rs b/actix-http/src/h1/codec.rs
index 2b452f8f8..a6ac3f895 100644
--- a/actix-http/src/h1/codec.rs
+++ b/actix-http/src/h1/codec.rs
@@ -6,6 +6,7 @@ use http::{Method, Version};
 use tokio_util::codec::{Decoder, Encoder};
 
 use super::{
+    big_bytes::BigBytes,
     decoder::{self, PayloadDecoder, PayloadItem, PayloadType},
     encoder, Message, MessageType,
 };
@@ -146,14 +147,12 @@ impl Decoder for Codec {
     }
 }
 
-impl Encoder<Message<(Response<()>, BodySize)>> for Codec {
-    type Error = io::Error;
-
-    fn encode(
+impl Codec {
+    pub(super) fn encode_bigbytes(
         &mut self,
         item: Message<(Response<()>, BodySize)>,
-        dst: &mut BytesMut,
-    ) -> Result<(), Self::Error> {
+        dst: &mut BigBytes,
+    ) -> std::io::Result<()> {
         match item {
             Message::Item((mut res, length)) => {
                 // set response version
@@ -171,7 +170,7 @@ impl Encoder<Message<(Response<()>, BodySize)>> for Codec {
                 };
 
                 // encode message
-                self.encoder.encode(
+                self.encoder.encode_bigbytes(
                     dst,
                     &mut res,
                     self.flags.contains(Flags::HEAD),
@@ -184,11 +183,11 @@ impl Encoder<Message<(Response<()>, BodySize)>> for Codec {
             }
 
             Message::Chunk(Some(bytes)) => {
-                self.encoder.encode_chunk(bytes.as_ref(), dst)?;
+                self.encoder.encode_chunk_bigbytes(bytes, dst)?;
             }
 
             Message::Chunk(None) => {
-                self.encoder.encode_eof(dst)?;
+                self.encoder.encode_eof_bigbytes(dst)?;
             }
         }
 
@@ -196,6 +195,23 @@ impl Encoder<Message<(Response<()>, BodySize)>> for Codec {
     }
 }
 
+impl Encoder<Message<(Response<()>, BodySize)>> for Codec {
+    type Error = io::Error;
+
+    fn encode(
+        &mut self,
+        item: Message<(Response<()>, BodySize)>,
+        dst: &mut BytesMut,
+    ) -> Result<(), Self::Error> {
+        let mut bigbytes = BigBytes::with_capacity(1024 * 8);
+        self.encode_bigbytes(item, &mut bigbytes)?;
+
+        bigbytes.write_to(dst);
+
+        Ok(())
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/actix-http/src/h1/dispatcher.rs b/actix-http/src/h1/dispatcher.rs
index 00b51360e..30d8c61d7 100644
--- a/actix-http/src/h1/dispatcher.rs
+++ b/actix-http/src/h1/dispatcher.rs
@@ -12,14 +12,15 @@ use actix_codec::{Framed, FramedParts};
 use actix_rt::time::sleep_until;
 use actix_service::Service;
 use bitflags::bitflags;
-use bytes::{Buf, BytesMut};
+use bytes::BytesMut;
 use futures_core::ready;
 use pin_project_lite::pin_project;
 use tokio::io::{AsyncRead, AsyncWrite};
-use tokio_util::codec::{Decoder as _, Encoder as _};
+use tokio_util::codec::Decoder as _;
 use tracing::{error, trace};
 
 use super::{
+    big_bytes::BigBytes,
     codec::Codec,
     decoder::MAX_BUFFER_SIZE,
     payload::{Payload, PayloadSender, PayloadStatus},
@@ -165,7 +166,7 @@ pin_project! {
 
         pub(super) io: Option<T>,
         read_buf: BytesMut,
-        write_buf: BytesMut,
+        write_buf: BigBytes,
         codec: Codec,
     }
 }
@@ -277,7 +278,7 @@ where
                     io: Some(io),
                     read_buf: BytesMut::with_capacity(HW_BUFFER_SIZE),
-                    write_buf: BytesMut::with_capacity(HW_BUFFER_SIZE),
+                    write_buf: BigBytes::with_capacity(HW_BUFFER_SIZE),
                     codec: Codec::new(config),
                 },
             },
@@ -329,20 +330,17 @@ where
         let InnerDispatcherProj { io, write_buf, .. } = self.project();
         let mut io = Pin::new(io.as_mut().unwrap());
 
-        let len = write_buf.len();
-        let mut written = 0;
-
-        while written < len {
-            match io.as_mut().poll_write(cx, &write_buf[written..])? {
+        while write_buf.total_len() > 0 {
+            match io.as_mut().poll_write(cx, write_buf.front_slice())? {
                 Poll::Ready(0) => {
+                    println!("WRITE ZERO");
                     error!("write zero; closing");
                     return Poll::Ready(Err(io::Error::new(io::ErrorKind::WriteZero, "")));
                 }
 
-                Poll::Ready(n) => written += n,
+                Poll::Ready(n) => write_buf.advance(n),
 
                 Poll::Pending => {
-                    write_buf.advance(written);
                     return Poll::Pending;
                 }
             }
@@ -365,7 +363,7 @@ where
         let size = body.size();
 
         this.codec
-            .encode(Message::Item((res, size)), this.write_buf)
+            .encode_bigbytes(Message::Item((res, size)), this.write_buf)
             .map_err(|err| {
                 if let Some(mut payload) = this.payload.take() {
                     payload.set_error(PayloadError::Incomplete(None));
@@ -493,15 +491,16 @@ where
                 StateProj::SendPayload { mut body } => {
                     // keep populate writer buffer until buffer size limit hit,
                     // get blocked or finished.
-                    while this.write_buf.len() < super::payload::MAX_BUFFER_SIZE {
+                    while this.write_buf.total_len() < super::payload::MAX_BUFFER_SIZE {
                         match body.as_mut().poll_next(cx) {
                             Poll::Ready(Some(Ok(item))) => {
                                 this.codec
-                                    .encode(Message::Chunk(Some(item)), this.write_buf)?;
+                                    .encode_bigbytes(Message::Chunk(Some(item)), this.write_buf)?;
                             }
 
                             Poll::Ready(None) => {
-                                this.codec.encode(Message::Chunk(None), this.write_buf)?;
+                                this.codec
+                                    .encode_bigbytes(Message::Chunk(None), this.write_buf)?;
 
                                 // payload stream finished.
                                 // set state to None and handle next message
@@ -532,15 +531,16 @@ where
 
                     // keep populate writer buffer until buffer size limit hit,
                     // get blocked or finished.
-                    while this.write_buf.len() < super::payload::MAX_BUFFER_SIZE {
+                    while this.write_buf.total_len() < super::payload::MAX_BUFFER_SIZE {
                         match body.as_mut().poll_next(cx) {
                             Poll::Ready(Some(Ok(item))) => {
                                 this.codec
-                                    .encode(Message::Chunk(Some(item)), this.write_buf)?;
+                                    .encode_bigbytes(Message::Chunk(Some(item)), this.write_buf)?;
                             }
 
                             Poll::Ready(None) => {
-                                this.codec.encode(Message::Chunk(None), this.write_buf)?;
+                                this.codec
+                                    .encode_bigbytes(Message::Chunk(None), this.write_buf)?;
 
                                 // payload stream finished
                                 // set state to None and handle next message
@@ -1027,7 +1027,7 @@ where
                             mem::take(this.codec),
                             mem::take(this.read_buf),
                         );
-                        parts.write_buf = mem::take(this.write_buf);
+                        this.write_buf.write_to(&mut parts.write_buf);
                         let framed = Framed::from_parts(parts);
                         this.flow.upgrade.as_ref().unwrap().call((req, framed))
                     }
diff --git a/actix-http/src/h1/encoder.rs b/actix-http/src/h1/encoder.rs
index abe396ce2..90eef5916 100644
--- a/actix-http/src/h1/encoder.rs
+++ b/actix-http/src/h1/encoder.rs
@@ -6,7 +6,7 @@ use std::{
     slice::from_raw_parts_mut,
 };
 
-use bytes::{BufMut, BytesMut};
+use bytes::{BufMut, Bytes, BytesMut};
 
 use crate::{
     body::BodySize,
@@ -16,6 +16,8 @@ use crate::{
     helpers, ConnectionType, RequestHeadType, Response, ServiceConfig, StatusCode, Version,
 };
 
+use super::big_bytes::BigBytes;
+
 const AVERAGE_HEADER_SIZE: usize = 30;
 
 #[derive(Debug)]
@@ -49,8 +51,183 @@ pub(crate) trait MessageType: Sized {
 
     fn chunked(&self) -> bool;
 
+    fn encode_status_bigbytes(&mut self, dst: &mut BigBytes) -> io::Result<()>;
     fn encode_status(&mut self, dst: &mut BytesMut) -> io::Result<()>;
 
+    fn encode_headers_bigbytes(
+        &mut self,
+        dst: &mut BigBytes,
+        version: Version,
+        mut length: BodySize,
+        conn_type: ConnectionType,
+        config: &ServiceConfig,
+    ) -> io::Result<()> {
+        let chunked = self.chunked();
+        let mut skip_len = length != BodySize::Stream;
+        let camel_case = self.camel_case();
+
+        // Content length
+        if let Some(status) = self.status() {
+            match status {
+                StatusCode::CONTINUE
+                | StatusCode::SWITCHING_PROTOCOLS
+                | StatusCode::PROCESSING
+                | StatusCode::NO_CONTENT => {
+                    // skip content-length and transfer-encoding headers
+                    // see https://datatracker.ietf.org/doc/html/rfc7230#section-3.3.1
+                    // and https://datatracker.ietf.org/doc/html/rfc7230#section-3.3.2
+                    skip_len = true;
+                    length = BodySize::None
+                }
+
+                StatusCode::NOT_MODIFIED => {
+                    // 304 responses should never have a body but should retain a manually set
+                    // content-length header
+                    // see https://datatracker.ietf.org/doc/html/rfc7232#section-4.1
+                    skip_len = false;
+                    length = BodySize::None;
+                }
+
+                _ => {}
+            }
+        }
+
+        match length {
+            BodySize::Stream => {
+                if chunked {
+                    skip_len = true;
+                    if camel_case {
+                        dst.extend_from_slice(b"\r\nTransfer-Encoding: chunked\r\n")
+                    } else {
+                        dst.extend_from_slice(b"\r\ntransfer-encoding: chunked\r\n")
+                    }
+                } else {
+                    skip_len = false;
+                    dst.extend_from_slice(b"\r\n");
+                }
+            }
+            BodySize::Sized(0) if camel_case => dst.extend_from_slice(b"\r\nContent-Length: 0\r\n"),
+            BodySize::Sized(0) => dst.extend_from_slice(b"\r\ncontent-length: 0\r\n"),
+            BodySize::Sized(len) => {
+                helpers::write_content_length(len, dst.buffer_mut(), camel_case)
+            }
+            BodySize::None => dst.extend_from_slice(b"\r\n"),
+        }
+
+        // Connection
+        match conn_type {
+            ConnectionType::Upgrade => dst.extend_from_slice(b"connection: upgrade\r\n"),
+            ConnectionType::KeepAlive if version < Version::HTTP_11 => {
+                if camel_case {
+                    dst.extend_from_slice(b"Connection: keep-alive\r\n")
+                } else {
+                    dst.extend_from_slice(b"connection: keep-alive\r\n")
+                }
+            }
+            ConnectionType::Close if version >= Version::HTTP_11 => {
+                if camel_case {
+                    dst.extend_from_slice(b"Connection: close\r\n")
+                } else {
+                    dst.extend_from_slice(b"connection: close\r\n")
+                }
+            }
+            _ => {}
+        }
+
+        // write headers
+
+        let mut has_date = false;
+
+        let dst = dst.buffer_mut();
+
+        let mut buf = dst.chunk_mut().as_mut_ptr();
+        let mut remaining = dst.capacity() - dst.len();
+
+        // tracks bytes written since last buffer resize
+        // since buf is a raw pointer to a bytes container storage but is written to without the
+        // container's knowledge, this is used to sync the container's cursor after data is written
+        let mut pos = 0;
+
+        self.write_headers(|key, value| {
+            match *key {
+                CONNECTION => return,
+                TRANSFER_ENCODING | CONTENT_LENGTH if skip_len => return,
+                DATE => has_date = true,
+                _ => {}
+            }
+
+            let k = key.as_str().as_bytes();
+            let k_len = k.len();
+
+            for val in value.iter() {
+                let v = val.as_ref();
+                let v_len = v.len();
+
+                // key length + value length + colon + space + \r\n
+                let len = k_len + v_len + 4;
+
+                if len > remaining {
+                    // SAFETY: all the bytes written up to position "pos" are initialized
+                    // the written byte count and pointer advancement are kept in sync
+                    unsafe {
+                        dst.advance_mut(pos);
+                    }
+
+                    pos = 0;
+                    dst.reserve(len * 2);
+                    remaining = dst.capacity() - dst.len();
+
+                    // re-assign buf raw pointer since it's possible that the buffer was
+                    // reallocated and/or resized
+                    buf = dst.chunk_mut().as_mut_ptr();
+                }
+
+                // SAFETY: on each write, it is enough to ensure that the advancement of
+                // the cursor matches the number of bytes written
+                unsafe {
+                    if camel_case {
+                        // use Camel-Case headers
+                        write_camel_case(k, buf, k_len);
+                    } else {
+                        write_data(k, buf, k_len);
+                    }
+
+                    buf = buf.add(k_len);
+
+                    write_data(b": ", buf, 2);
+                    buf = buf.add(2);
+
+                    write_data(v, buf, v_len);
+                    buf = buf.add(v_len);
+
+                    write_data(b"\r\n", buf, 2);
+                    buf = buf.add(2);
+                };
+
+                pos += len;
+                remaining -= len;
+            }
+        });
+
+        // final cursor synchronization with the bytes container
+        //
+        // SAFETY: all the bytes written up to position "pos" are initialized
+        // the written byte count and pointer advancement are kept in sync
+        unsafe {
+            dst.advance_mut(pos);
+        }
+
+        if !has_date {
+            // optimized date header, write_date_header writes its own \r\n
+            config.write_date_header(dst, camel_case);
+        }
+
+        // end-of-headers marker
+        dst.extend_from_slice(b"\r\n");
+
+        Ok(())
+    }
+
     fn encode_headers(
         &mut self,
         dst: &mut BytesMut,
@@ -263,6 +440,17 @@ impl MessageType for Response<()> {
             .contains(crate::message::Flags::CAMEL_CASE)
     }
 
+    fn encode_status_bigbytes(&mut self, dst: &mut BigBytes) -> io::Result<()> {
+        let head = self.head();
+        let reason = head.reason().as_bytes();
+        dst.reserve(256 + head.headers.len() * AVERAGE_HEADER_SIZE + reason.len());
+
+        // status line
+        helpers::write_status_line(head.version, head.status.as_u16(), dst.buffer_mut());
+        dst.extend_from_slice(reason);
+        Ok(())
+    }
+
     fn encode_status(&mut self, dst: &mut BytesMut) -> io::Result<()> {
         let head = self.head();
         let reason = head.reason().as_bytes();
@@ -296,6 +484,26 @@ impl MessageType for RequestHeadType {
         self.extra_headers()
     }
 
+    fn encode_status_bigbytes(&mut self, dst: &mut BigBytes) -> io::Result<()> {
+        let head = self.as_ref();
+        dst.reserve(256 + head.headers.len() * AVERAGE_HEADER_SIZE);
+        write!(
+            helpers::MutWriter(dst.buffer_mut()),
+            "{} {} {}",
+            head.method,
+            head.uri.path_and_query().map(|u| u.as_str()).unwrap_or("/"),
+            match head.version {
+                Version::HTTP_09 => "HTTP/0.9",
+                Version::HTTP_10 => "HTTP/1.0",
+                Version::HTTP_11 => "HTTP/1.1",
+                Version::HTTP_2 => "HTTP/2.0",
+                Version::HTTP_3 => "HTTP/3.0",
+                _ => return Err(io::Error::new(io::ErrorKind::Other, "unsupported version")),
+            }
+        )
+        .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
+    }
+
     fn encode_status(&mut self, dst: &mut BytesMut) -> io::Result<()> {
         let head = self.as_ref();
         dst.reserve(256 + head.headers.len() * AVERAGE_HEADER_SIZE);
@@ -323,11 +531,57 @@ impl<T: MessageType> MessageEncoder<T> {
         self.te.encode(msg, buf)
     }
 
+    pub(super) fn encode_chunk_bigbytes(
+        &mut self,
+        msg: Bytes,
+        buf: &mut BigBytes,
+    ) -> io::Result<bool> {
+        self.te.encode_bigbytes(msg, buf)
+    }
+
     /// Encode EOF.
     pub fn encode_eof(&mut self, buf: &mut BytesMut) -> io::Result<()> {
         self.te.encode_eof(buf)
     }
 
+    pub(super) fn encode_eof_bigbytes(&mut self, buf: &mut BigBytes) -> io::Result<()> {
+        self.te.encode_eof_bigbytes(buf)
+    }
+
+    /// Encode message.
+    pub(super) fn encode_bigbytes(
+        &mut self,
+        dst: &mut BigBytes,
+        message: &mut T,
+        head: bool,
+        stream: bool,
+        version: Version,
+        length: BodySize,
+        conn_type: ConnectionType,
+        config: &ServiceConfig,
+    ) -> io::Result<()> {
+        // transfer encoding
+        if !head {
+            self.te = match length {
+                BodySize::Sized(0) => TransferEncoding::empty(),
+                BodySize::Sized(len) => TransferEncoding::length(len),
+                BodySize::Stream => {
+                    if message.chunked() && !stream {
+                        TransferEncoding::chunked()
+                    } else {
+                        TransferEncoding::eof()
+                    }
+                }
+                BodySize::None => TransferEncoding::empty(),
+            };
+        } else {
+            self.te = TransferEncoding::empty();
+        }
+
+        message.encode_status_bigbytes(dst)?;
+        message.encode_headers_bigbytes(dst, version, length, conn_type, config)
+    }
+
     /// Encode message.
     pub fn encode(
         &mut self,
@@ -414,6 +668,51 @@ impl TransferEncoding {
         }
     }
 
+    #[inline]
+    /// Encode message. Return `EOF` state of encoder
+    pub(super) fn encode_bigbytes(&mut self, msg: Bytes, buf: &mut BigBytes) -> io::Result<bool> {
+        match self.kind {
+            TransferEncodingKind::Eof => {
+                let eof = msg.is_empty();
+                buf.put_bytes(msg);
+                Ok(eof)
+            }
+            TransferEncodingKind::Chunked(ref mut eof) => {
+                if *eof {
+                    return Ok(true);
+                }
+
+                if msg.is_empty() {
+                    *eof = true;
+                    buf.extend_from_slice(b"0\r\n\r\n");
+                } else {
+                    writeln!(helpers::MutWriter(buf.buffer_mut()), "{:X}\r", msg.len())
+                        .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
+
+                    buf.reserve(msg.len() + 2);
+                    buf.put_bytes(msg);
+                    buf.extend_from_slice(b"\r\n");
+                }
+                Ok(*eof)
+            }
+            TransferEncodingKind::Length(ref mut remaining) => {
+                if *remaining > 0 {
+                    if msg.is_empty() {
+                        return Ok(*remaining == 0);
+                    }
+                    let len = cmp::min(*remaining, msg.len() as u64);
+
+                    buf.put_bytes(msg.slice(..len as usize));
+
+                    *remaining -= len;
+                    Ok(*remaining == 0)
+                } else {
+                    Ok(true)
+                }
+            }
+        }
+    }
+
     /// Encode message. Return `EOF` state of encoder
     #[inline]
     pub fn encode(&mut self, msg: &[u8], buf: &mut BytesMut) -> io::Result<bool> {
@@ -459,6 +758,28 @@ impl TransferEncoding {
         }
     }
 
+    /// Encode eof. Return `EOF` state of encoder
+    #[inline]
+    pub fn encode_eof_bigbytes(&mut self, buf: &mut BigBytes) -> io::Result<()> {
+        match self.kind {
+            TransferEncodingKind::Eof => Ok(()),
+            TransferEncodingKind::Length(rem) => {
+                if rem != 0 {
+                    Err(io::Error::new(io::ErrorKind::UnexpectedEof, ""))
+                } else {
+                    Ok(())
+                }
+            }
+            TransferEncodingKind::Chunked(ref mut eof) => {
+                if !*eof {
+                    *eof = true;
+                    buf.extend_from_slice(b"0\r\n\r\n");
+                }
+                Ok(())
+            }
+        }
+    }
+
     /// Encode eof. Return `EOF` state of encoder
     #[inline]
     pub fn encode_eof(&mut self, buf: &mut BytesMut) -> io::Result<()> {
diff --git a/actix-http/src/h1/mod.rs b/actix-http/src/h1/mod.rs
index 9e44608d8..267b20120 100644
--- a/actix-http/src/h1/mod.rs
+++ b/actix-http/src/h1/mod.rs
@@ -2,6 +2,7 @@
 
 use bytes::{Bytes, BytesMut};
 
+mod big_bytes;
 mod chunked;
 mod client;
 mod codec;
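
Note for reviewers (not part of the patch): the sketch below shows how BigBytes is intended to
split small and large writes, based only on the methods added above. It assumes a hypothetical
#[cfg(test)] module placed at the bottom of actix-http/src/h1/big_bytes.rs, where the pub(super)
type is visible; the test itself is illustrative, not something this patch adds.

    #[cfg(test)]
    mod sketch {
        use bytes::{Bytes, BytesMut};

        use super::BigBytes;

        #[test]
        fn large_chunks_skip_the_copy_buffer() {
            let mut big = BigBytes::with_capacity(1024 * 8);

            // Small writes (status line, headers, chunk framing) are copied into the
            // internal BytesMut buffer.
            big.extend_from_slice(b"HTTP/1.1 200 OK\r\n\r\n");

            // A body of 64KB or more is not copied: the pending buffer is frozen and
            // the Bytes handle is queued behind it, so only reference counts change.
            let body = Bytes::from(vec![0u8; 64 * 1024]);
            big.put_bytes(body.clone());
            assert_eq!(big.total_len(), 19 + body.len());

            // write_to drains everything, in order, into a plain BytesMut; this is the
            // path used by the Encoder impl in codec.rs and the upgrade handoff in
            // dispatcher.rs.
            let mut wire = BytesMut::new();
            big.write_to(&mut wire);
            assert_eq!(wire.len(), 19 + body.len());
            assert!(big.is_empty());
        }
    }

The fast path is poll_flush in dispatcher.rs, which now writes straight from front_slice() and
advance() without collapsing the queue first, so large response bodies reach the socket without an
extra copy into the write buffer.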