From 880b863f958fdbead8f604fd935e928eb442120d Mon Sep 17 00:00:00 2001 From: fakeshadow <24548779@qq.com> Date: Sun, 7 Mar 2021 10:33:16 -0800 Subject: [PATCH] fix h1 client for handling expect header request (#2049) --- actix-http/src/client/h1proto.rs | 83 ++++++++++++++++++++++++-------- actix-http/tests/test_client.rs | 61 ++++++++++++++++++++++- 2 files changed, 121 insertions(+), 23 deletions(-) diff --git a/actix-http/src/client/h1proto.rs b/actix-http/src/client/h1proto.rs index 082c4b8e2..d2db18cec 100644 --- a/actix-http/src/client/h1proto.rs +++ b/actix-http/src/client/h1proto.rs @@ -7,13 +7,15 @@ use actix_codec::{AsyncRead, AsyncWrite, Framed, ReadBuf}; use bytes::buf::BufMut; use bytes::{Bytes, BytesMut}; use futures_core::Stream; -use futures_util::future::poll_fn; -use futures_util::{SinkExt, StreamExt}; +use futures_util::{future::poll_fn, SinkExt, StreamExt}; use crate::error::PayloadError; use crate::h1; use crate::header::HeaderMap; -use crate::http::header::{IntoHeaderValue, HOST}; +use crate::http::{ + header::{IntoHeaderValue, EXPECT, HOST}, + StatusCode, +}; use crate::message::{RequestHeadType, ResponseHead}; use crate::payload::{Payload, PayloadStream}; @@ -66,33 +68,72 @@ where io: Some(io), }; - // create Framed and send request - let mut framed_inner = Framed::new(io, h1::ClientCodec::default()); - framed_inner.send((head, body.size()).into()).await?; + // create Framed and prepare sending request + let mut framed = Framed::new(io, h1::ClientCodec::default()); - // send request body - match body.size() { - BodySize::None | BodySize::Empty | BodySize::Sized(0) => {} - _ => send_body(body, Pin::new(&mut framed_inner)).await?, - }; + // Check EXPECT header and enable expect handle flag accordingly. + // + // RFC: https://tools.ietf.org/html/rfc7231#section-5.1.1 + let is_expect = if head.as_ref().headers.contains_key(EXPECT) { + match body.size() { + BodySize::None | BodySize::Empty | BodySize::Sized(0) => { + let pin_framed = Pin::new(&mut framed); - // read response and init read body - let res = Pin::new(&mut framed_inner).into_future().await; - let (head, framed) = if let (Some(result), framed) = res { - let item = result.map_err(SendRequestError::from)?; - (item, framed) + let force_close = !pin_framed.codec_ref().keepalive(); + release_connection(pin_framed, force_close); + + // TODO: use a new variant or a new type better describing error violate + // `Requirements for clients` session of above RFC + return Err(SendRequestError::Connect(ConnectError::Disconnected)); + } + _ => true, + } } else { - return Err(SendRequestError::from(ConnectError::Disconnected)); + false }; - match framed.codec_ref().message_type() { + framed.send((head, body.size()).into()).await?; + + let mut pin_framed = Pin::new(&mut framed); + + // special handle for EXPECT request. + let (do_send, mut res_head) = if is_expect { + let head = poll_fn(|cx| pin_framed.as_mut().poll_next(cx)) + .await + .ok_or(ConnectError::Disconnected)??; + + // return response head in case status code is not continue + // and current head would be used as final response head. + (head.status == StatusCode::CONTINUE, Some(head)) + } else { + (true, None) + }; + + if do_send { + // send request body + match body.size() { + BodySize::None | BodySize::Empty | BodySize::Sized(0) => {} + _ => send_body(body, pin_framed.as_mut()).await?, + }; + + // read response and init read body + let head = poll_fn(|cx| pin_framed.as_mut().poll_next(cx)) + .await + .ok_or(ConnectError::Disconnected)??; + + res_head = Some(head); + } + + let head = res_head.unwrap(); + + match pin_framed.codec_ref().message_type() { h1::MessageType::None => { - let force_close = !framed.codec_ref().keepalive(); - release_connection(framed, force_close); + let force_close = !pin_framed.codec_ref().keepalive(); + release_connection(pin_framed, force_close); Ok((head, Payload::None)) } _ => { - let pl: PayloadStream = PlStream::new(framed_inner).boxed_local(); + let pl: PayloadStream = Box::pin(PlStream::new(framed)); Ok((head, pl.into())) } } diff --git a/actix-http/tests/test_client.rs b/actix-http/tests/test_client.rs index 91b2412f4..a50f2404d 100644 --- a/actix-http/tests/test_client.rs +++ b/actix-http/tests/test_client.rs @@ -1,8 +1,13 @@ -use actix_http::{http, HttpService, Request, Response}; +use actix_http::{ + error, http, http::StatusCode, HttpMessage, HttpService, Request, Response, +}; use actix_http_test::test_server; use actix_service::ServiceFactoryExt; use bytes::Bytes; -use futures_util::future::{self, ok}; +use futures_util::{ + future::{self, ok}, + StreamExt, +}; const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ Hello World Hello World Hello World Hello World Hello World \ @@ -88,3 +93,55 @@ async fn test_with_query_parameter() { let response = request.send().await.unwrap(); assert!(response.status().is_success()); } + +#[actix_rt::test] +async fn test_h1_expect() { + let srv = test_server(move || { + HttpService::build() + .expect(|req: Request| async { + if req.headers().contains_key("AUTH") { + Ok(req) + } else { + Err(error::ErrorExpectationFailed("expect failed")) + } + }) + .h1(|req: Request| async move { + let (_, mut body) = req.into_parts(); + let mut buf = Vec::new(); + while let Some(Ok(chunk)) = body.next().await { + buf.extend_from_slice(&chunk); + } + let str = std::str::from_utf8(&buf).unwrap(); + assert_eq!(str, "expect body"); + + Ok::<_, ()>(Response::Ok().finish()) + }) + .tcp() + }) + .await; + + // test expect without payload. + let request = srv + .request(http::Method::GET, srv.url("/")) + .insert_header(("Expect", "100-continue")); + + let response = request.send().await; + assert!(response.is_err()); + + // test expect would fail to continue + let request = srv + .request(http::Method::GET, srv.url("/")) + .insert_header(("Expect", "100-continue")); + + let response = request.send_body("expect body").await.unwrap(); + assert_eq!(response.status(), StatusCode::EXPECTATION_FAILED); + + // test exepct would continue + let request = srv + .request(http::Method::GET, srv.url("/")) + .insert_header(("Expect", "100-continue")) + .insert_header(("AUTH", "996")); + + let response = request.send_body("expect body").await.unwrap(); + assert!(response.status().is_success()); +}