1
0
Fork 0
mirror of https://github.com/actix/actix-web.git synced 2024-12-17 21:56:38 +00:00

actix-http: play with allowing bytes passthrough in the h1 encoder

This commit is contained in:
asonix 2024-05-18 12:30:19 -05:00
parent fff45b28f4
commit 35cadbbe0b
5 changed files with 478 additions and 29 deletions

View file

@ -0,0 +1,111 @@
use std::collections::VecDeque;
use bytes::{Buf, BufMut, Bytes, BytesMut};
// Size threshold of 64 KiB: `BigBytes::put_bytes` queues chunks of at least this size instead
// of copying them, and `BigBytes::reserve` ignores requests of at least this size.
const SIXTYFOUR_KB: usize = 1024 * 64;
/// Write buffer made of a queue of immutable `Bytes` chunks plus one contiguous `BytesMut`.
///
/// When the data is read back out (`front_slice`, `write_to`), the queued chunks always come
/// before the contents of `buffer`.
pub(super) struct BigBytes {
/// Contiguous buffer that receives slices and small chunks.
buffer: BytesMut,
/// Chunks queued by `put_bytes`, oldest first.
frozen: VecDeque<Bytes>,
/// Sum of the lengths of all chunks currently held in `frozen`.
frozen_len: usize,
}
impl BigBytes {
pub(super) fn with_capacity(capacity: usize) -> Self {
Self {
buffer: BytesMut::with_capacity(capacity),
frozen: VecDeque::default(),
frozen_len: 0,
}
}
// Clear the internal queue and buffer, resetting length to zero
pub(super) fn clear(&mut self) {
std::mem::take(&mut self.frozen);
self.frozen_len = 0;
self.buffer.clear();
}
// Return a mutable reference to the underlying buffer. This should only be used when dealing
// with small allocations (e.g. writing headers)
pub(super) fn buffer_mut(&mut self) -> &mut BytesMut {
&mut self.buffer
}
// Reserve the requested size, if fewer than 64KB
pub(super) fn reserve(&mut self, count: usize) {
if count < SIXTYFOUR_KB {
self.buffer.reserve(count);
}
}
pub(super) fn total_len(&mut self) -> usize {
self.frozen_len + self.buffer.len()
}
pub(super) fn is_empty(&self) -> bool {
self.frozen_len == 0 && self.buffer.is_empty()
}
// Add the `bytes` to the internal structure. If `bytes` exceeds 64KB, it is pushed into a
// queue, otherwise, it is added to a buffer.
pub(super) fn put_bytes(&mut self, bytes: Bytes) {
if bytes.len() < SIXTYFOUR_KB {
self.buffer.extend_from_slice(&bytes);
} else {
if !self.buffer.is_empty() {
let current = self.buffer.split().freeze();
self.frozen_len += current.len();
self.frozen.push_back(current);
}
self.frozen_len += bytes.len();
self.frozen.push_back(bytes);
}
}
// Put a slice into the internal structure. This is always added to the internal buffer
pub(super) fn extend_from_slice(&mut self, slice: &[u8]) {
self.buffer.extend_from_slice(slice);
}
// Returns a slice of the frontmost buffer
pub(super) fn front_slice(&self) -> &[u8] {
if let Some(front) = self.frozen.front() {
&front
} else {
&self.buffer
}
}
// Advances the first buffer by `count` bytes. If the first buffer is advanced to completion,
// it is popped from the queue
pub(super) fn advance(&mut self, count: usize) {
if let Some(front) = self.frozen.front_mut() {
front.advance(count);
if front.is_empty() {
self.frozen.pop_front();
}
self.frozen_len -= count;
} else {
self.buffer.advance(count);
}
}
// Drain the BibBytes, writing everything into the provided BytesMut
pub(super) fn write_to(&mut self, dst: &mut BytesMut) {
dst.reserve(self.total_len());
for buf in &self.frozen {
dst.put_slice(buf);
}
dst.put_slice(&self.buffer.split());
self.frozen_len = 0;
std::mem::take(&mut self.frozen);
}
}

View file

@ -6,6 +6,7 @@ use http::{Method, Version};
use tokio_util::codec::{Decoder, Encoder};
use super::{
big_bytes::BigBytes,
decoder::{self, PayloadDecoder, PayloadItem, PayloadType},
encoder, Message, MessageType,
};
@ -146,14 +147,12 @@ impl Decoder for Codec {
}
}
impl Encoder<Message<(Response<()>, BodySize)>> for Codec {
type Error = io::Error;
fn encode(
impl Codec {
pub(super) fn encode_bigbytes(
&mut self,
item: Message<(Response<()>, BodySize)>,
dst: &mut BytesMut,
) -> Result<(), Self::Error> {
dst: &mut BigBytes,
) -> std::io::Result<()> {
match item {
Message::Item((mut res, length)) => {
// set response version
@ -171,7 +170,7 @@ impl Encoder<Message<(Response<()>, BodySize)>> for Codec {
};
// encode message
self.encoder.encode(
self.encoder.encode_bigbytes(
dst,
&mut res,
self.flags.contains(Flags::HEAD),
@ -184,11 +183,11 @@ impl Encoder<Message<(Response<()>, BodySize)>> for Codec {
}
Message::Chunk(Some(bytes)) => {
self.encoder.encode_chunk(bytes.as_ref(), dst)?;
self.encoder.encode_chunk_bigbytes(bytes, dst)?;
}
Message::Chunk(None) => {
self.encoder.encode_eof(dst)?;
self.encoder.encode_eof_bigbytes(dst)?;
}
}
@ -196,6 +195,23 @@ impl Encoder<Message<(Response<()>, BodySize)>> for Codec {
}
}
impl Encoder<Message<(Response<()>, BodySize)>> for Codec {
    type Error = io::Error;

    /// Encodes `item` into a temporary `BigBytes` and, on success, copies the encoded bytes
    /// into `dst`. Nothing is written to `dst` when encoding fails.
    fn encode(
        &mut self,
        item: Message<(Response<()>, BodySize)>,
        dst: &mut BytesMut,
    ) -> Result<(), Self::Error> {
        // staging area with 8 KiB of contiguous space
        let mut staged = BigBytes::with_capacity(1024 * 8);

        self.encode_bigbytes(item, &mut staged)
            .map(|()| staged.write_to(dst))
    }
}
#[cfg(test)]
mod tests {
use super::*;

View file

@ -12,14 +12,15 @@ use actix_codec::{Framed, FramedParts};
use actix_rt::time::sleep_until;
use actix_service::Service;
use bitflags::bitflags;
use bytes::{Buf, BytesMut};
use bytes::BytesMut;
use futures_core::ready;
use pin_project_lite::pin_project;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::codec::{Decoder as _, Encoder as _};
use tokio_util::codec::Decoder as _;
use tracing::{error, trace};
use super::{
big_bytes::BigBytes,
codec::Codec,
decoder::MAX_BUFFER_SIZE,
payload::{Payload, PayloadSender, PayloadStatus},
@ -165,7 +166,7 @@ pin_project! {
pub(super) io: Option<T>,
read_buf: BytesMut,
write_buf: BytesMut,
write_buf: BigBytes,
codec: Codec,
}
}
@ -277,7 +278,7 @@ where
io: Some(io),
read_buf: BytesMut::with_capacity(HW_BUFFER_SIZE),
write_buf: BytesMut::with_capacity(HW_BUFFER_SIZE),
write_buf: BigBytes::with_capacity(HW_BUFFER_SIZE),
codec: Codec::new(config),
},
},
@ -329,20 +330,17 @@ where
let InnerDispatcherProj { io, write_buf, .. } = self.project();
let mut io = Pin::new(io.as_mut().unwrap());
let len = write_buf.len();
let mut written = 0;
while written < len {
match io.as_mut().poll_write(cx, &write_buf[written..])? {
while write_buf.total_len() > 0 {
match io.as_mut().poll_write(cx, write_buf.front_slice())? {
Poll::Ready(0) => {
println!("WRITE ZERO");
error!("write zero; closing");
return Poll::Ready(Err(io::Error::new(io::ErrorKind::WriteZero, "")));
}
Poll::Ready(n) => written += n,
Poll::Ready(n) => write_buf.advance(n),
Poll::Pending => {
write_buf.advance(written);
return Poll::Pending;
}
}
@ -365,7 +363,7 @@ where
let size = body.size();
this.codec
.encode(Message::Item((res, size)), this.write_buf)
.encode_bigbytes(Message::Item((res, size)), this.write_buf)
.map_err(|err| {
if let Some(mut payload) = this.payload.take() {
payload.set_error(PayloadError::Incomplete(None));
@ -493,15 +491,16 @@ where
StateProj::SendPayload { mut body } => {
// keep populate writer buffer until buffer size limit hit,
// get blocked or finished.
while this.write_buf.len() < super::payload::MAX_BUFFER_SIZE {
while this.write_buf.total_len() < super::payload::MAX_BUFFER_SIZE {
match body.as_mut().poll_next(cx) {
Poll::Ready(Some(Ok(item))) => {
this.codec
.encode(Message::Chunk(Some(item)), this.write_buf)?;
.encode_bigbytes(Message::Chunk(Some(item)), this.write_buf)?;
}
Poll::Ready(None) => {
this.codec.encode(Message::Chunk(None), this.write_buf)?;
this.codec
.encode_bigbytes(Message::Chunk(None), this.write_buf)?;
// payload stream finished.
// set state to None and handle next message
@ -532,15 +531,16 @@ where
// keep populate writer buffer until buffer size limit hit,
// get blocked or finished.
while this.write_buf.len() < super::payload::MAX_BUFFER_SIZE {
while this.write_buf.total_len() < super::payload::MAX_BUFFER_SIZE {
match body.as_mut().poll_next(cx) {
Poll::Ready(Some(Ok(item))) => {
this.codec
.encode(Message::Chunk(Some(item)), this.write_buf)?;
.encode_bigbytes(Message::Chunk(Some(item)), this.write_buf)?;
}
Poll::Ready(None) => {
this.codec.encode(Message::Chunk(None), this.write_buf)?;
this.codec
.encode_bigbytes(Message::Chunk(None), this.write_buf)?;
// payload stream finished
// set state to None and handle next message
@ -1027,7 +1027,7 @@ where
mem::take(this.codec),
mem::take(this.read_buf),
);
parts.write_buf = mem::take(this.write_buf);
this.write_buf.write_to(&mut parts.write_buf);
let framed = Framed::from_parts(parts);
this.flow.upgrade.as_ref().unwrap().call((req, framed))
}

View file

@ -6,7 +6,7 @@ use std::{
slice::from_raw_parts_mut,
};
use bytes::{BufMut, BytesMut};
use bytes::{BufMut, Bytes, BytesMut};
use crate::{
body::BodySize,
@ -16,6 +16,8 @@ use crate::{
helpers, ConnectionType, RequestHeadType, Response, ServiceConfig, StatusCode, Version,
};
use super::big_bytes::BigBytes;
const AVERAGE_HEADER_SIZE: usize = 30;
#[derive(Debug)]
@ -49,8 +51,183 @@ pub(crate) trait MessageType: Sized {
fn chunked(&self) -> bool;
/// Writes the start line of the message (the status line for a response, the request line
/// for a request) into `dst`.
fn encode_status_bigbytes(&mut self, dst: &mut BigBytes) -> io::Result<()>;
/// Writes the start line of the message into `dst`.
///
/// NOTE(review): presumably the `BytesMut` counterpart of `encode_status_bigbytes`; confirm
/// against the implementations.
fn encode_status(&mut self, dst: &mut BytesMut) -> io::Result<()>;
/// Writes the header section of the message into `dst`.
///
/// In order, this emits:
/// 1. the body-framing header selected from `length` (content-length or
///    `transfer-encoding: chunked`), subject to the status-code overrides below;
/// 2. a connection header derived from `conn_type` and `version`;
/// 3. the message's own headers, via `write_headers`;
/// 4. a date header from `config` when the message did not carry one;
/// 5. the blank line that ends the header section.
///
/// As written, this function always returns `Ok(())`.
fn encode_headers_bigbytes(
&mut self,
dst: &mut BigBytes,
version: Version,
mut length: BodySize,
conn_type: ConnectionType,
config: &ServiceConfig,
) -> io::Result<()> {
let chunked = self.chunked();
// when `skip_len` is set, transfer-encoding and content-length headers carried by the
// message itself are dropped in the `write_headers` callback below
let mut skip_len = length != BodySize::Stream;
let camel_case = self.camel_case();
// Content length
if let Some(status) = self.status() {
match status {
StatusCode::CONTINUE
| StatusCode::SWITCHING_PROTOCOLS
| StatusCode::PROCESSING
| StatusCode::NO_CONTENT => {
// skip content-length and transfer-encoding headers
// see https://datatracker.ietf.org/doc/html/rfc7230#section-3.3.1
// and https://datatracker.ietf.org/doc/html/rfc7230#section-3.3.2
skip_len = true;
length = BodySize::None
}
StatusCode::NOT_MODIFIED => {
// 304 responses should never have a body but should retain a manually set
// content-length header
// see https://datatracker.ietf.org/doc/html/rfc7232#section-4.1
skip_len = false;
length = BodySize::None;
}
_ => {}
}
}
// NOTE(review): every literal written in this match begins with "\r\n", presumably to
// terminate the start line written just before the headers; confirm against the
// `encode_status_bigbytes` implementations
match length {
BodySize::Stream => {
if chunked {
skip_len = true;
if camel_case {
dst.extend_from_slice(b"\r\nTransfer-Encoding: chunked\r\n")
} else {
dst.extend_from_slice(b"\r\ntransfer-encoding: chunked\r\n")
}
} else {
skip_len = false;
dst.extend_from_slice(b"\r\n");
}
}
BodySize::Sized(0) if camel_case => dst.extend_from_slice(b"\r\nContent-Length: 0\r\n"),
BodySize::Sized(0) => dst.extend_from_slice(b"\r\ncontent-length: 0\r\n"),
BodySize::Sized(len) => {
helpers::write_content_length(len, dst.buffer_mut(), camel_case)
}
BodySize::None => dst.extend_from_slice(b"\r\n"),
}
// Connection
match conn_type {
ConnectionType::Upgrade => dst.extend_from_slice(b"connection: upgrade\r\n"),
ConnectionType::KeepAlive if version < Version::HTTP_11 => {
if camel_case {
dst.extend_from_slice(b"Connection: keep-alive\r\n")
} else {
dst.extend_from_slice(b"connection: keep-alive\r\n")
}
}
ConnectionType::Close if version >= Version::HTTP_11 => {
if camel_case {
dst.extend_from_slice(b"Connection: close\r\n")
} else {
dst.extend_from_slice(b"connection: close\r\n")
}
}
_ => {}
}
// write headers
let mut has_date = false;
// from here on `dst` shadows the `BigBytes` and refers to its contiguous `BytesMut`
let dst = dst.buffer_mut();
let mut buf = dst.chunk_mut().as_mut_ptr();
let mut remaining = dst.capacity() - dst.len();
// tracks bytes written since last buffer resize
// since buf is a raw pointer to a bytes container storage but is written to without the
// container's knowledge, this is used to sync the containers cursor after data is written
let mut pos = 0;
self.write_headers(|key, value| {
match *key {
// the connection header was already written above from `conn_type`
CONNECTION => return,
TRANSFER_ENCODING | CONTENT_LENGTH if skip_len => return,
DATE => has_date = true,
_ => {}
}
let k = key.as_str().as_bytes();
let k_len = k.len();
for val in value.iter() {
let v = val.as_ref();
let v_len = v.len();
// key length + value length + colon + space + \r\n
let len = k_len + v_len + 4;
if len > remaining {
// SAFETY: all the bytes written up to position "pos" are initialized
// the written byte count and pointer advancement are kept in sync
unsafe {
dst.advance_mut(pos);
}
pos = 0;
dst.reserve(len * 2);
remaining = dst.capacity() - dst.len();
// re-assign buf raw pointer since it's possible that the buffer was
// reallocated and/or resized
buf = dst.chunk_mut().as_mut_ptr();
}
// SAFETY: on each write, it is enough to ensure that the advancement of
// the cursor matches the number of bytes written
unsafe {
if camel_case {
// use Camel-Case headers
write_camel_case(k, buf, k_len);
} else {
write_data(k, buf, k_len);
}
buf = buf.add(k_len);
write_data(b": ", buf, 2);
buf = buf.add(2);
write_data(v, buf, v_len);
buf = buf.add(v_len);
write_data(b"\r\n", buf, 2);
buf = buf.add(2);
};
pos += len;
remaining -= len;
}
});
// final cursor synchronization with the bytes container
//
// SAFETY: all the bytes written up to position "pos" are initialized
// the written byte count and pointer advancement are kept in sync
unsafe {
dst.advance_mut(pos);
}
if !has_date {
// optimized date header, write_date_header writes its own \r\n
config.write_date_header(dst, camel_case);
}
// end-of-headers marker
dst.extend_from_slice(b"\r\n");
Ok(())
}
fn encode_headers(
&mut self,
dst: &mut BytesMut,
@ -263,6 +440,17 @@ impl MessageType for Response<()> {
.contains(crate::message::Flags::CAMEL_CASE)
}
/// Writes the response status line, followed by the reason phrase, into `dst`.
fn encode_status_bigbytes(&mut self, dst: &mut BigBytes) -> io::Result<()> {
    let head = self.head();
    let reason_phrase = head.reason().as_bytes();

    // rough size estimate for the encoded head, reserved up front
    let size_hint = 256 + head.headers.len() * AVERAGE_HEADER_SIZE + reason_phrase.len();
    dst.reserve(size_hint);

    // status line goes to the contiguous buffer, then the reason phrase is appended
    helpers::write_status_line(head.version, head.status.as_u16(), dst.buffer_mut());
    dst.extend_from_slice(reason_phrase);

    Ok(())
}
fn encode_status(&mut self, dst: &mut BytesMut) -> io::Result<()> {
let head = self.head();
let reason = head.reason().as_bytes();
@ -296,6 +484,26 @@ impl MessageType for RequestHeadType {
self.extra_headers()
}
/// Writes the request line (`<method> <path-and-query> <version>`) into `dst`.
///
/// # Errors
///
/// Returns an error for an HTTP version that is not listed below, or when formatting into the
/// buffer fails.
fn encode_status_bigbytes(&mut self, dst: &mut BigBytes) -> io::Result<()> {
    let head = self.as_ref();
    dst.reserve(256 + head.headers.len() * AVERAGE_HEADER_SIZE);

    let version = match head.version {
        Version::HTTP_09 => "HTTP/0.9",
        Version::HTTP_10 => "HTTP/1.0",
        Version::HTTP_11 => "HTTP/1.1",
        Version::HTTP_2 => "HTTP/2.0",
        Version::HTTP_3 => "HTTP/3.0",
        _ => return Err(io::Error::new(io::ErrorKind::Other, "unsupported version")),
    };

    // fall back to the root path when the URI carries no path-and-query component
    let target = head.uri.path_and_query().map_or("/", |pq| pq.as_str());

    write!(
        helpers::MutWriter(dst.buffer_mut()),
        "{} {} {}",
        head.method,
        target,
        version
    )
    .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
}
fn encode_status(&mut self, dst: &mut BytesMut) -> io::Result<()> {
let head = self.as_ref();
dst.reserve(256 + head.headers.len() * AVERAGE_HEADER_SIZE);
@ -323,11 +531,57 @@ impl<T: MessageType> MessageEncoder<T> {
self.te.encode(msg, buf)
}
/// Encode one body chunk into `buf` using the current transfer encoding.
///
/// Delegates to the transfer encoder's `encode_bigbytes` and returns its result unchanged.
pub(super) fn encode_chunk_bigbytes(
&mut self,
msg: Bytes,
buf: &mut BigBytes,
) -> io::Result<bool> {
self.te.encode_bigbytes(msg, buf)
}
/// Encode EOF.
///
/// Delegates to the transfer encoder's `encode_eof` and returns its result unchanged.
pub fn encode_eof(&mut self, buf: &mut BytesMut) -> io::Result<()> {
self.te.encode_eof(buf)
}
/// Encode EOF into a `BigBytes`.
///
/// Delegates to the transfer encoder's `encode_eof_bigbytes` and returns its result unchanged.
pub(super) fn encode_eof_bigbytes(&mut self, buf: &mut BigBytes) -> io::Result<()> {
self.te.encode_eof_bigbytes(buf)
}
/// Encode message.
///
/// Picks the transfer encoding for the body that follows, then writes the start line and the
/// header section of `message` into `dst`.
pub(super) fn encode_bigbytes(
    &mut self,
    dst: &mut BigBytes,
    message: &mut T,
    head: bool,
    stream: bool,
    version: Version,
    length: BodySize,
    conn_type: ConnectionType,
    config: &ServiceConfig,
) -> io::Result<()> {
    // transfer encoding; when `head` is set no body is encoded at all
    self.te = if head {
        TransferEncoding::empty()
    } else {
        match length {
            BodySize::None | BodySize::Sized(0) => TransferEncoding::empty(),
            BodySize::Sized(len) => TransferEncoding::length(len),
            BodySize::Stream if message.chunked() && !stream => TransferEncoding::chunked(),
            BodySize::Stream => TransferEncoding::eof(),
        }
    };

    message.encode_status_bigbytes(dst)?;
    message.encode_headers_bigbytes(dst, version, length, conn_type, config)
}
/// Encode message.
pub fn encode(
&mut self,
@ -414,6 +668,51 @@ impl TransferEncoding {
}
}
/// Encode message. Return `EOF` state of encoder
///
/// `msg` is handed to `buf` as an owned `Bytes`; an empty `msg` signals the end of the body
/// for the `Eof` and `Chunked` encodings.
#[inline]
pub(super) fn encode_bigbytes(&mut self, msg: Bytes, buf: &mut BigBytes) -> io::Result<bool> {
    match &mut self.kind {
        TransferEncodingKind::Eof => {
            let finished = msg.is_empty();
            buf.put_bytes(msg);
            Ok(finished)
        }

        TransferEncodingKind::Chunked(done) => {
            // terminating chunk already written; ignore anything that follows
            if *done {
                return Ok(true);
            }

            // an empty message produces the terminating chunk
            if msg.is_empty() {
                *done = true;
                buf.extend_from_slice(b"0\r\n\r\n");
                return Ok(true);
            }

            // chunk size in hex, then the data, then the chunk terminator
            writeln!(helpers::MutWriter(buf.buffer_mut()), "{:X}\r", msg.len())
                .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
            buf.reserve(msg.len() + 2);
            buf.put_bytes(msg);
            buf.extend_from_slice(b"\r\n");

            Ok(false)
        }

        TransferEncodingKind::Length(remaining) => {
            // declared length already reached
            if *remaining == 0 {
                return Ok(true);
            }

            if msg.is_empty() {
                return Ok(false);
            }

            // never write more than the declared length
            let take = cmp::min(*remaining, msg.len() as u64);
            buf.put_bytes(msg.slice(..take as usize));
            *remaining -= take;

            Ok(*remaining == 0)
        }
    }
}
/// Encode message. Return `EOF` state of encoder
#[inline]
pub fn encode(&mut self, msg: &[u8], buf: &mut BytesMut) -> io::Result<bool> {
@ -459,6 +758,28 @@ impl TransferEncoding {
}
}
/// Encode eof.
///
/// Writes the terminating chunk for a chunked body that has not been terminated yet.
///
/// # Errors
///
/// Returns `UnexpectedEof` when a length-delimited body still has bytes outstanding.
#[inline]
pub fn encode_eof_bigbytes(&mut self, buf: &mut BigBytes) -> io::Result<()> {
    match &mut self.kind {
        // nothing to write for these
        TransferEncodingKind::Eof | TransferEncodingKind::Length(0) => {}

        // fewer bytes were written than the declared length
        TransferEncodingKind::Length(_) => {
            return Err(io::Error::new(io::ErrorKind::UnexpectedEof, ""));
        }

        TransferEncodingKind::Chunked(done) => {
            if !*done {
                *done = true;
                buf.extend_from_slice(b"0\r\n\r\n");
            }
        }
    }

    Ok(())
}
/// Encode eof. Return `EOF` state of encoder
#[inline]
pub fn encode_eof(&mut self, buf: &mut BytesMut) -> io::Result<()> {

View file

@ -2,6 +2,7 @@
use bytes::{Bytes, BytesMut};
mod big_bytes;
mod chunked;
mod client;
mod codec;