From 69dd1a9bd643d8eb4f2689b6332eb16de4e06e90 Mon Sep 17 00:00:00 2001 From: fakeshadow <24548779@qq.com> Date: Mon, 15 Mar 2021 19:56:23 -0700 Subject: [PATCH] Remove ConnectionLifetime trait. Simplify Acquired handling (#2072) --- actix-http/src/client/connection.rs | 26 +------ actix-http/src/client/h1proto.rs | 117 +++++++++++++--------------- actix-http/src/client/h2proto.rs | 22 +++--- actix-http/src/client/pool.rs | 18 ++--- actix-http/src/h1/expect.rs | 2 - actix-http/src/h1/upgrade.rs | 2 - src/app_service.rs | 1 - src/resource.rs | 1 - src/scope.rs | 1 - 9 files changed, 76 insertions(+), 114 deletions(-) diff --git a/actix-http/src/client/connection.rs b/actix-http/src/client/connection.rs index 97ecd0515..9fb5c58cc 100644 --- a/actix-http/src/client/connection.rs +++ b/actix-http/src/client/connection.rs @@ -94,14 +94,6 @@ pub trait Connection { >; } -pub(crate) trait ConnectionLifetime: AsyncRead + AsyncWrite + 'static { - /// Close connection - fn close(self: Pin<&mut Self>); - - /// Release connection to the connection pool - fn release(self: Pin<&mut Self>); -} - #[doc(hidden)] /// HTTP client connection pub struct IoConnection @@ -110,7 +102,7 @@ where { io: Option>, created: time::Instant, - pool: Option>, + pool: Acquired, } impl fmt::Debug for IoConnection @@ -130,7 +122,7 @@ impl IoConnection { pub(crate) fn new( io: ConnectionType, created: time::Instant, - pool: Option>, + pool: Acquired, ) -> Self { IoConnection { pool, @@ -139,13 +131,9 @@ impl IoConnection { } } - pub(crate) fn into_inner(self) -> (ConnectionType, time::Instant) { - (self.io.unwrap(), self.created) - } - #[cfg(test)] pub(crate) fn into_parts(self) -> (ConnectionType, time::Instant, Acquired) { - (self.io.unwrap(), self.created, self.pool.unwrap()) + (self.io.unwrap(), self.created, self.pool) } async fn send_request>( @@ -173,13 +161,7 @@ impl IoConnection { match self.io.take().unwrap() { ConnectionType::H1(io) => h1proto::open_tunnel(io, head.into()).await, ConnectionType::H2(io) => { - if let Some(mut pool) = self.pool.take() { - pool.release(IoConnection::new( - ConnectionType::H2(io), - self.created, - None, - )); - } + self.pool.release(ConnectionType::H2(io), self.created); Err(SendRequestError::TunnelNotSupported) } } diff --git a/actix-http/src/client/h1proto.rs b/actix-http/src/client/h1proto.rs index d2db18cec..cd0b45734 100644 --- a/actix-http/src/client/h1proto.rs +++ b/actix-http/src/client/h1proto.rs @@ -7,7 +7,7 @@ use actix_codec::{AsyncRead, AsyncWrite, Framed, ReadBuf}; use bytes::buf::BufMut; use bytes::{Bytes, BytesMut}; use futures_core::Stream; -use futures_util::{future::poll_fn, SinkExt, StreamExt}; +use futures_util::{future::poll_fn, SinkExt}; use crate::error::PayloadError; use crate::h1; @@ -19,7 +19,7 @@ use crate::http::{ use crate::message::{RequestHeadType, ResponseHead}; use crate::payload::{Payload, PayloadStream}; -use super::connection::{ConnectionLifetime, ConnectionType, IoConnection}; +use super::connection::ConnectionType; use super::error::{ConnectError, SendRequestError}; use super::pool::Acquired; use crate::body::{BodySize, MessageBody}; @@ -29,7 +29,7 @@ pub(crate) async fn send_request( mut head: RequestHeadType, body: B, created: time::Instant, - pool: Option>, + acquired: Acquired, ) -> Result<(ResponseHead, Payload), SendRequestError> where T: AsyncRead + AsyncWrite + Unpin + 'static, @@ -42,9 +42,9 @@ where if let Some(host) = head.as_ref().uri.host() { let mut wrt = BytesMut::with_capacity(host.len() + 5).writer(); - let _ = match head.as_ref().uri.port_u16() { - None | Some(80) | Some(443) => write!(wrt, "{}", host), - Some(port) => write!(wrt, "{}:{}", host, port), + match head.as_ref().uri.port_u16() { + None | Some(80) | Some(443) => write!(wrt, "{}", host)?, + Some(port) => write!(wrt, "{}:{}", host, port)?, }; match wrt.get_mut().split().freeze().try_into_value() { @@ -64,7 +64,7 @@ where let io = H1Connection { created, - pool, + acquired, io: Some(io), }; @@ -77,10 +77,8 @@ where let is_expect = if head.as_ref().headers.contains_key(EXPECT) { match body.size() { BodySize::None | BodySize::Empty | BodySize::Sized(0) => { - let pin_framed = Pin::new(&mut framed); - - let force_close = !pin_framed.codec_ref().keepalive(); - release_connection(pin_framed, force_close); + let keep_alive = framed.codec_ref().keepalive(); + framed.io_mut().on_release(keep_alive); // TODO: use a new variant or a new type better describing error violate // `Requirements for clients` session of above RFC @@ -128,8 +126,9 @@ where match pin_framed.codec_ref().message_type() { h1::MessageType::None => { - let force_close = !pin_framed.codec_ref().keepalive(); - release_connection(pin_framed, force_close); + let keep_alive = pin_framed.codec_ref().keepalive(); + pin_framed.io_mut().on_release(keep_alive); + Ok((head, Payload::None)) } _ => { @@ -151,12 +150,11 @@ where framed.send((head, BodySize::None).into()).await?; // read response - if let (Some(result), framed) = framed.into_future().await { - let head = result.map_err(SendRequestError::from)?; - Ok((head, framed)) - } else { - Err(SendRequestError::from(ConnectError::Disconnected)) - } + let head = poll_fn(|cx| Pin::new(&mut framed).poll_next(cx)) + .await + .ok_or(ConnectError::Disconnected)??; + + Ok((head, framed)) } /// send request body to the peer @@ -165,7 +163,7 @@ pub(crate) async fn send_body( mut framed: Pin<&mut Framed>, ) -> Result<(), SendRequestError> where - T: ConnectionLifetime + Unpin, + T: AsyncRead + AsyncWrite + Unpin + 'static, B: MessageBody, { actix_rt::pin!(body); @@ -200,7 +198,7 @@ where } } - SinkExt::flush(Pin::into_inner(framed)).await?; + SinkExt::flush(framed.get_mut()).await?; Ok(()) } @@ -208,41 +206,37 @@ where /// HTTP client connection pub struct H1Connection where - T: AsyncWrite + Unpin + 'static, + T: AsyncRead + AsyncWrite + Unpin + 'static, { /// T should be `Unpin` io: Option, created: time::Instant, - pool: Option>, + acquired: Acquired, } -impl ConnectionLifetime for H1Connection +impl H1Connection where T: AsyncRead + AsyncWrite + Unpin + 'static, { + fn on_release(&mut self, keep_alive: bool) { + if keep_alive { + self.release(); + } else { + self.close(); + } + } + /// Close connection - fn close(mut self: Pin<&mut Self>) { - if let Some(mut pool) = self.pool.take() { - if let Some(io) = self.io.take() { - pool.close(IoConnection::new( - ConnectionType::H1(io), - self.created, - None, - )); - } + fn close(&mut self) { + if let Some(io) = self.io.take() { + self.acquired.close(ConnectionType::H1(io)); } } /// Release this connection to the connection pool - fn release(mut self: Pin<&mut Self>) { - if let Some(mut pool) = self.pool.take() { - if let Some(io) = self.io.take() { - pool.release(IoConnection::new( - ConnectionType::H1(io), - self.created, - None, - )); - } + fn release(&mut self) { + if let Some(io) = self.io.take() { + self.acquired.release(ConnectionType::H1(io), self.created); } } } @@ -282,13 +276,19 @@ impl AsyncWrite for H1Connection } #[pin_project::pin_project] -pub(crate) struct PlStream { +pub(crate) struct PlStream +where + Io: AsyncRead + AsyncWrite + Unpin + 'static, +{ #[pin] - framed: Option>, + framed: Option, h1::ClientPayloadCodec>>, } -impl PlStream { - fn new(framed: Framed) -> Self { +impl PlStream +where + Io: AsyncRead + AsyncWrite + Unpin + 'static, +{ + fn new(framed: Framed, h1::ClientCodec>) -> Self { let framed = framed.into_map_codec(|codec| codec.into_payload_codec()); PlStream { @@ -297,24 +297,26 @@ impl PlStream { } } -impl Stream for PlStream { +impl Stream for PlStream +where + Io: AsyncRead + AsyncWrite + Unpin + 'static, +{ type Item = Result; fn poll_next( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { - let mut this = self.project(); + let mut framed = self.project().framed.as_pin_mut().unwrap(); - match this.framed.as_mut().as_pin_mut().unwrap().next_item(cx)? { + match framed.as_mut().next_item(cx)? { Poll::Pending => Poll::Pending, Poll::Ready(Some(chunk)) => { if let Some(chunk) = chunk { Poll::Ready(Some(Ok(chunk))) } else { - let framed = this.framed.as_mut().as_pin_mut().unwrap(); - let force_close = !framed.codec_ref().keepalive(); - release_connection(framed, force_close); + let keep_alive = framed.codec_ref().keepalive(); + framed.io_mut().on_release(keep_alive); Poll::Ready(None) } } @@ -322,14 +324,3 @@ impl Stream for PlStream { } } } - -fn release_connection(framed: Pin<&mut Framed>, force_close: bool) -where - T: ConnectionLifetime, -{ - if !force_close && framed.is_read_buf_empty() && framed.is_write_buf_empty() { - framed.io_pin().release() - } else { - framed.io_pin().close() - } -} diff --git a/actix-http/src/client/h2proto.rs b/actix-http/src/client/h2proto.rs index 7292972de..82e81c7ff 100644 --- a/actix-http/src/client/h2proto.rs +++ b/actix-http/src/client/h2proto.rs @@ -17,7 +17,7 @@ use crate::message::{RequestHeadType, ResponseHead}; use crate::payload::Payload; use super::config::ConnectorConfig; -use super::connection::{ConnectionType, IoConnection}; +use super::connection::ConnectionType; use super::error::SendRequestError; use super::pool::Acquired; use crate::client::connection::H2Connection; @@ -27,7 +27,7 @@ pub(crate) async fn send_request( head: RequestHeadType, body: B, created: time::Instant, - pool: Option>, + acquired: Acquired, ) -> Result<(ResponseHead, Payload), SendRequestError> where T: AsyncRead + AsyncWrite + Unpin + 'static, @@ -103,13 +103,13 @@ where let res = poll_fn(|cx| io.poll_ready(cx)).await; if let Err(e) = res { - release(io, pool, created, e.is_io()); + release(io, acquired, created, e.is_io()); return Err(SendRequestError::from(e)); } let resp = match io.send_request(req, eof) { Ok((fut, send)) => { - release(io, pool, created, false); + release(io, acquired, created, false); if !eof { send_body(body, send).await?; @@ -117,7 +117,7 @@ where fut.await.map_err(SendRequestError::from)? } Err(e) => { - release(io, pool, created, e.is_io()); + release(io, acquired, created, e.is_io()); return Err(e.into()); } }; @@ -181,16 +181,14 @@ async fn send_body( /// release SendRequest object fn release( io: H2Connection, - pool: Option>, + acquired: Acquired, created: time::Instant, close: bool, ) { - if let Some(mut pool) = pool { - if close { - pool.close(IoConnection::new(ConnectionType::H2(io), created, None)); - } else { - pool.release(IoConnection::new(ConnectionType::H2(io), created, None)); - } + if close { + acquired.close(ConnectionType::H2(io)); + } else { + acquired.release(ConnectionType::H2(io), created); } } diff --git a/actix-http/src/client/pool.rs b/actix-http/src/client/pool.rs index b5b4a30d3..79adcf3e2 100644 --- a/actix-http/src/client/pool.rs +++ b/actix-http/src/client/pool.rs @@ -217,7 +217,7 @@ where // construct acquired. It's used to put Io type back to pool/ close the Io type. // permit is carried with the whole lifecycle of Acquired. - let acquired = Some(Acquired { key, inner, permit }); + let acquired = Acquired { key, inner, permit }; // match the connection and spawn new one if did not get anything. match conn { @@ -235,7 +235,7 @@ where acquired, )) } else { - let config = &acquired.as_ref().unwrap().inner.config; + let config = &acquired.inner.config; let (sender, connection) = handshake(io, config).await?; Ok(IoConnection::new( ConnectionType::H2(H2Connection::new(sender, connection)), @@ -346,14 +346,12 @@ where Io: AsyncRead + AsyncWrite + Unpin + 'static, { /// Close the IO. - pub(crate) fn close(&mut self, conn: IoConnection) { - let (conn, _) = conn.into_inner(); + pub(crate) fn close(&self, conn: ConnectionType) { self.inner.close(conn); } /// Release IO back into pool. - pub(crate) fn release(&mut self, conn: IoConnection) { - let (io, created) = conn.into_inner(); + pub(crate) fn release(&self, conn: ConnectionType, created: Instant) { let Acquired { key, inner, .. } = self; inner @@ -362,12 +360,12 @@ where .entry(key.clone()) .or_insert_with(VecDeque::new) .push_back(PooledConnection { - conn: io, + conn, created, used: Instant::now(), }); - let _ = &mut self.permit; + let _ = &self.permit; } } @@ -447,8 +445,8 @@ mod test { where T: AsyncRead + AsyncWrite + Unpin + 'static, { - let (conn, created, mut acquired) = conn.into_parts(); - acquired.release(IoConnection::new(conn, created, None)); + let (conn, created, acquired) = conn.into_parts(); + acquired.release(conn, created); } #[actix_rt::test] diff --git a/actix-http/src/h1/expect.rs b/actix-http/src/h1/expect.rs index 65856edf6..5015069bb 100644 --- a/actix-http/src/h1/expect.rs +++ b/actix-http/src/h1/expect.rs @@ -1,5 +1,3 @@ -use std::task::Poll; - use actix_service::{Service, ServiceFactory}; use futures_util::future::{ready, Ready}; diff --git a/actix-http/src/h1/upgrade.rs b/actix-http/src/h1/upgrade.rs index fbfbc83c1..e57ea8ae9 100644 --- a/actix-http/src/h1/upgrade.rs +++ b/actix-http/src/h1/upgrade.rs @@ -1,5 +1,3 @@ -use std::task::Poll; - use actix_codec::Framed; use actix_service::{Service, ServiceFactory}; use futures_core::future::LocalBoxFuture; diff --git a/src/app_service.rs b/src/app_service.rs index 9b4ae3354..be4ccf22f 100644 --- a/src/app_service.rs +++ b/src/app_service.rs @@ -1,6 +1,5 @@ use std::cell::RefCell; use std::rc::Rc; -use std::task::Poll; use actix_http::{Extensions, Request, Response}; use actix_router::{Path, ResourceDef, Router, Url}; diff --git a/src/resource.rs b/src/resource.rs index 944beeefa..1a5619de6 100644 --- a/src/resource.rs +++ b/src/resource.rs @@ -2,7 +2,6 @@ use std::cell::RefCell; use std::fmt; use std::future::Future; use std::rc::Rc; -use std::task::Poll; use actix_http::{Error, Extensions, Response}; use actix_router::IntoPattern; diff --git a/src/scope.rs b/src/scope.rs index dd02501b0..e96f5b9e8 100644 --- a/src/scope.rs +++ b/src/scope.rs @@ -2,7 +2,6 @@ use std::cell::RefCell; use std::fmt; use std::future::Future; use std::rc::Rc; -use std::task::Poll; use actix_http::Extensions; use actix_router::{ResourceDef, Router};