From 0648ad6f33d463bbde2952e648a4965b87b3d7e1 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Thu, 11 Jan 2018 15:26:46 -0800 Subject: [PATCH] fix implicit chunked encoding --- src/encoding.rs | 81 +++++++++++++++++++++----------- tests/test_server.rs | 108 +++++++++++++++++++++++++++++++------------ 2 files changed, 133 insertions(+), 56 deletions(-) diff --git a/src/encoding.rs b/src/encoding.rs index 90b084141..1e2a4c726 100644 --- a/src/encoding.rs +++ b/src/encoding.rs @@ -369,33 +369,8 @@ impl PayloadEncoder { resp.headers_mut().remove(CONTENT_ENCODING); } TransferEncoding::eof(buf) - } else if resp.chunked() { - resp.headers_mut().remove(CONTENT_LENGTH); - if version != Version::HTTP_11 { - error!("Chunked transfer encoding is forbidden for {:?}", version); - } - if version == Version::HTTP_2 { - resp.headers_mut().remove(TRANSFER_ENCODING); - TransferEncoding::eof(buf) - } else { - resp.headers_mut().insert( - TRANSFER_ENCODING, HeaderValue::from_static("chunked")); - TransferEncoding::chunked(buf) - } - } else if let Some(len) = resp.headers().get(CONTENT_LENGTH) { - // Content-Length - if let Ok(s) = len.to_str() { - if let Ok(len) = s.parse::() { - TransferEncoding::length(len, buf) - } else { - debug!("illegal Content-Length: {:?}", len); - TransferEncoding::eof(buf) - } - } else { - TransferEncoding::eof(buf) - } } else { - TransferEncoding::eof(buf) + PayloadEncoder::streaming_encoding(buf, version, resp) } } }; @@ -414,6 +389,60 @@ impl PayloadEncoder { } ) } + + fn streaming_encoding(buf: SharedBytes, version: Version, + resp: &mut HttpResponse) -> TransferEncoding { + if resp.chunked() { + // Enable transfer encoding + resp.headers_mut().remove(CONTENT_LENGTH); + if version == Version::HTTP_2 { + resp.headers_mut().remove(TRANSFER_ENCODING); + TransferEncoding::eof(buf) + } else { + resp.headers_mut().insert( + TRANSFER_ENCODING, HeaderValue::from_static("chunked")); + TransferEncoding::chunked(buf) + } + } else { + // if Content-Length is specified, then use it as length hint + let (len, chunked) = + if let Some(len) = resp.headers().get(CONTENT_LENGTH) { + // Content-Length + if let Ok(s) = len.to_str() { + if let Ok(len) = s.parse::() { + (Some(len), false) + } else { + error!("illegal Content-Length: {:?}", len); + (None, false) + } + } else { + error!("illegal Content-Length: {:?}", len); + (None, false) + } + } else { + (None, true) + }; + + if !chunked { + if let Some(len) = len { + TransferEncoding::length(len, buf) + } else { + TransferEncoding::eof(buf) + } + } else { + // Enable transfer encoding + resp.headers_mut().remove(CONTENT_LENGTH); + if version == Version::HTTP_2 { + resp.headers_mut().remove(TRANSFER_ENCODING); + TransferEncoding::eof(buf) + } else { + resp.headers_mut().insert( + TRANSFER_ENCODING, HeaderValue::from_static("chunked")); + TransferEncoding::chunked(buf) + } + } + } + } } impl PayloadEncoder { diff --git a/tests/test_server.rs b/tests/test_server.rs index e8d58d751..0a6eb4693 100644 --- a/tests/test_server.rs +++ b/tests/test_server.rs @@ -16,7 +16,8 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use flate2::Compression; use flate2::write::{GzEncoder, DeflateEncoder, DeflateDecoder}; use brotli2::write::{BrotliEncoder, BrotliDecoder}; -use futures::Future; +use futures::{Future, Stream}; +use futures::stream::once; use h2::client; use bytes::{Bytes, BytesMut, BufMut}; use http::Request; @@ -113,6 +114,41 @@ fn test_body_gzip() { assert_eq!(bytes, Bytes::from_static(STR.as_ref())); } +#[test] +fn test_body_streaming_implicit() { + let srv = test::TestServer::new( + |app| app.handler(|_| { + let body = once(Ok(Bytes::from_static(STR.as_ref()))); + httpcodes::HTTPOk.build() + .content_encoding(headers::ContentEncoding::Gzip) + .body(Body::Streaming(Box::new(body)))})); + + let mut res = reqwest::get(&srv.url("/")).unwrap(); + assert!(res.status().is_success()); + let mut bytes = BytesMut::with_capacity(2048).writer(); + let _ = res.copy_to(&mut bytes); + let bytes = bytes.into_inner(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); +} + +#[test] +fn test_body_streaming_explicit() { + let srv = test::TestServer::new( + |app| app.handler(|_| { + let body = once(Ok(Bytes::from_static(STR.as_ref()))); + httpcodes::HTTPOk.build() + .chunked() + .content_encoding(headers::ContentEncoding::Gzip) + .body(Body::Streaming(Box::new(body)))})); + + let mut res = reqwest::get(&srv.url("/")).unwrap(); + assert!(res.status().is_success()); + let mut bytes = BytesMut::with_capacity(2048).writer(); + let _ = res.copy_to(&mut bytes); + let bytes = bytes.into_inner(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); +} + #[test] fn test_body_deflate() { let srv = test::TestServer::new( @@ -153,35 +189,6 @@ fn test_body_brotli() { assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); } -#[test] -fn test_h2() { - let srv = test::TestServer::new(|app| app.handler(httpcodes::HTTPOk)); - let addr = srv.addr(); - - let mut core = Core::new().unwrap(); - let handle = core.handle(); - let tcp = TcpStream::connect(&addr, &handle); - - let tcp = tcp.then(|res| { - client::handshake(res.unwrap()) - }).then(move |res| { - let (mut client, h2) = res.unwrap(); - - let request = Request::builder() - .uri(format!("https://{}/", addr).as_str()) - .body(()) - .unwrap(); - let (response, _) = client.send_request(request, false).unwrap(); - - // Spawn a task to run the conn... - handle.spawn(h2.map_err(|e| println!("GOT ERR={:?}", e))); - - response - }); - let resp = core.run(tcp).unwrap(); - assert_eq!(resp.status(), StatusCode::OK); -} - #[test] fn test_gzip_encoding() { let srv = test::TestServer::new(|app| app.handler(|req: HttpRequest| { @@ -260,6 +267,47 @@ fn test_brotli_encoding() { assert_eq!(bytes, Bytes::from_static(STR.as_ref())); } +#[test] +fn test_h2() { + let srv = test::TestServer::new(|app| app.handler(|_|{ + httpcodes::HTTPOk.build().body(STR) + })); + let addr = srv.addr(); + + let mut core = Core::new().unwrap(); + let handle = core.handle(); + let tcp = TcpStream::connect(&addr, &handle); + + let tcp = tcp.then(|res| { + client::handshake(res.unwrap()) + }).then(move |res| { + let (mut client, h2) = res.unwrap(); + + let request = Request::builder() + .uri(format!("https://{}/", addr).as_str()) + .body(()) + .unwrap(); + let (response, _) = client.send_request(request, false).unwrap(); + + // Spawn a task to run the conn... + handle.spawn(h2.map_err(|e| println!("GOT ERR={:?}", e))); + + response.and_then(|response| { + assert_eq!(response.status(), StatusCode::OK); + + let (_, body) = response.into_parts(); + + body.fold(BytesMut::new(), |mut b, c| -> Result<_, h2::Error> { + b.extend(c); + Ok(b) + }) + }) + }); + let res = core.run(tcp).unwrap(); + + assert_eq!(res, Bytes::from_static(STR.as_ref())); +} + #[test] fn test_application() { let srv = test::TestServer::with_factory(