1
0
Fork 0
mirror of https://github.com/actix/actix-web.git synced 2024-06-10 17:29:36 +00:00

Remove ConnectionLifetime trait. Simplify Acquired handling (#2072)

This commit is contained in:
fakeshadow 2021-03-15 19:56:23 -07:00 committed by GitHub
parent d93314a683
commit 69dd1a9bd6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 76 additions and 114 deletions

View file

@ -94,14 +94,6 @@ pub trait Connection {
>;
}
pub(crate) trait ConnectionLifetime: AsyncRead + AsyncWrite + 'static {
/// Close connection
fn close(self: Pin<&mut Self>);
/// Release connection to the connection pool
fn release(self: Pin<&mut Self>);
}
#[doc(hidden)]
/// HTTP client connection
pub struct IoConnection<T>
@ -110,7 +102,7 @@ where
{
io: Option<ConnectionType<T>>,
created: time::Instant,
pool: Option<Acquired<T>>,
pool: Acquired<T>,
}
impl<T> fmt::Debug for IoConnection<T>
@ -130,7 +122,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> IoConnection<T> {
pub(crate) fn new(
io: ConnectionType<T>,
created: time::Instant,
pool: Option<Acquired<T>>,
pool: Acquired<T>,
) -> Self {
IoConnection {
pool,
@ -139,13 +131,9 @@ impl<T: AsyncRead + AsyncWrite + Unpin> IoConnection<T> {
}
}
pub(crate) fn into_inner(self) -> (ConnectionType<T>, time::Instant) {
(self.io.unwrap(), self.created)
}
#[cfg(test)]
pub(crate) fn into_parts(self) -> (ConnectionType<T>, time::Instant, Acquired<T>) {
(self.io.unwrap(), self.created, self.pool.unwrap())
(self.io.unwrap(), self.created, self.pool)
}
async fn send_request<B: MessageBody + 'static, H: Into<RequestHeadType>>(
@ -173,13 +161,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> IoConnection<T> {
match self.io.take().unwrap() {
ConnectionType::H1(io) => h1proto::open_tunnel(io, head.into()).await,
ConnectionType::H2(io) => {
if let Some(mut pool) = self.pool.take() {
pool.release(IoConnection::new(
ConnectionType::H2(io),
self.created,
None,
));
}
self.pool.release(ConnectionType::H2(io), self.created);
Err(SendRequestError::TunnelNotSupported)
}
}

View file

@ -7,7 +7,7 @@ use actix_codec::{AsyncRead, AsyncWrite, Framed, ReadBuf};
use bytes::buf::BufMut;
use bytes::{Bytes, BytesMut};
use futures_core::Stream;
use futures_util::{future::poll_fn, SinkExt, StreamExt};
use futures_util::{future::poll_fn, SinkExt};
use crate::error::PayloadError;
use crate::h1;
@ -19,7 +19,7 @@ use crate::http::{
use crate::message::{RequestHeadType, ResponseHead};
use crate::payload::{Payload, PayloadStream};
use super::connection::{ConnectionLifetime, ConnectionType, IoConnection};
use super::connection::ConnectionType;
use super::error::{ConnectError, SendRequestError};
use super::pool::Acquired;
use crate::body::{BodySize, MessageBody};
@ -29,7 +29,7 @@ pub(crate) async fn send_request<T, B>(
mut head: RequestHeadType,
body: B,
created: time::Instant,
pool: Option<Acquired<T>>,
acquired: Acquired<T>,
) -> Result<(ResponseHead, Payload), SendRequestError>
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
@ -42,9 +42,9 @@ where
if let Some(host) = head.as_ref().uri.host() {
let mut wrt = BytesMut::with_capacity(host.len() + 5).writer();
let _ = match head.as_ref().uri.port_u16() {
None | Some(80) | Some(443) => write!(wrt, "{}", host),
Some(port) => write!(wrt, "{}:{}", host, port),
match head.as_ref().uri.port_u16() {
None | Some(80) | Some(443) => write!(wrt, "{}", host)?,
Some(port) => write!(wrt, "{}:{}", host, port)?,
};
match wrt.get_mut().split().freeze().try_into_value() {
@ -64,7 +64,7 @@ where
let io = H1Connection {
created,
pool,
acquired,
io: Some(io),
};
@ -77,10 +77,8 @@ where
let is_expect = if head.as_ref().headers.contains_key(EXPECT) {
match body.size() {
BodySize::None | BodySize::Empty | BodySize::Sized(0) => {
let pin_framed = Pin::new(&mut framed);
let force_close = !pin_framed.codec_ref().keepalive();
release_connection(pin_framed, force_close);
let keep_alive = framed.codec_ref().keepalive();
framed.io_mut().on_release(keep_alive);
// TODO: use a new variant or a new type better describing error violate
// `Requirements for clients` session of above RFC
@ -128,8 +126,9 @@ where
match pin_framed.codec_ref().message_type() {
h1::MessageType::None => {
let force_close = !pin_framed.codec_ref().keepalive();
release_connection(pin_framed, force_close);
let keep_alive = pin_framed.codec_ref().keepalive();
pin_framed.io_mut().on_release(keep_alive);
Ok((head, Payload::None))
}
_ => {
@ -151,12 +150,11 @@ where
framed.send((head, BodySize::None).into()).await?;
// read response
if let (Some(result), framed) = framed.into_future().await {
let head = result.map_err(SendRequestError::from)?;
Ok((head, framed))
} else {
Err(SendRequestError::from(ConnectError::Disconnected))
}
let head = poll_fn(|cx| Pin::new(&mut framed).poll_next(cx))
.await
.ok_or(ConnectError::Disconnected)??;
Ok((head, framed))
}
/// send request body to the peer
@ -165,7 +163,7 @@ pub(crate) async fn send_body<T, B>(
mut framed: Pin<&mut Framed<T, h1::ClientCodec>>,
) -> Result<(), SendRequestError>
where
T: ConnectionLifetime + Unpin,
T: AsyncRead + AsyncWrite + Unpin + 'static,
B: MessageBody,
{
actix_rt::pin!(body);
@ -200,7 +198,7 @@ where
}
}
SinkExt::flush(Pin::into_inner(framed)).await?;
SinkExt::flush(framed.get_mut()).await?;
Ok(())
}
@ -208,41 +206,37 @@ where
/// HTTP client connection
pub struct H1Connection<T>
where
T: AsyncWrite + Unpin + 'static,
T: AsyncRead + AsyncWrite + Unpin + 'static,
{
/// T should be `Unpin`
io: Option<T>,
created: time::Instant,
pool: Option<Acquired<T>>,
acquired: Acquired<T>,
}
impl<T> ConnectionLifetime for H1Connection<T>
impl<T> H1Connection<T>
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
{
fn on_release(&mut self, keep_alive: bool) {
if keep_alive {
self.release();
} else {
self.close();
}
}
/// Close connection
fn close(mut self: Pin<&mut Self>) {
if let Some(mut pool) = self.pool.take() {
if let Some(io) = self.io.take() {
pool.close(IoConnection::new(
ConnectionType::H1(io),
self.created,
None,
));
}
fn close(&mut self) {
if let Some(io) = self.io.take() {
self.acquired.close(ConnectionType::H1(io));
}
}
/// Release this connection to the connection pool
fn release(mut self: Pin<&mut Self>) {
if let Some(mut pool) = self.pool.take() {
if let Some(io) = self.io.take() {
pool.release(IoConnection::new(
ConnectionType::H1(io),
self.created,
None,
));
}
fn release(&mut self) {
if let Some(io) = self.io.take() {
self.acquired.release(ConnectionType::H1(io), self.created);
}
}
}
@ -282,13 +276,19 @@ impl<T: AsyncRead + AsyncWrite + Unpin + 'static> AsyncWrite for H1Connection<T>
}
#[pin_project::pin_project]
pub(crate) struct PlStream<Io> {
pub(crate) struct PlStream<Io>
where
Io: AsyncRead + AsyncWrite + Unpin + 'static,
{
#[pin]
framed: Option<Framed<Io, h1::ClientPayloadCodec>>,
framed: Option<Framed<H1Connection<Io>, h1::ClientPayloadCodec>>,
}
impl<Io: ConnectionLifetime> PlStream<Io> {
fn new(framed: Framed<Io, h1::ClientCodec>) -> Self {
impl<Io> PlStream<Io>
where
Io: AsyncRead + AsyncWrite + Unpin + 'static,
{
fn new(framed: Framed<H1Connection<Io>, h1::ClientCodec>) -> Self {
let framed = framed.into_map_codec(|codec| codec.into_payload_codec());
PlStream {
@ -297,24 +297,26 @@ impl<Io: ConnectionLifetime> PlStream<Io> {
}
}
impl<Io: ConnectionLifetime> Stream for PlStream<Io> {
impl<Io> Stream for PlStream<Io>
where
Io: AsyncRead + AsyncWrite + Unpin + 'static,
{
type Item = Result<Bytes, PayloadError>;
fn poll_next(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
let mut this = self.project();
let mut framed = self.project().framed.as_pin_mut().unwrap();
match this.framed.as_mut().as_pin_mut().unwrap().next_item(cx)? {
match framed.as_mut().next_item(cx)? {
Poll::Pending => Poll::Pending,
Poll::Ready(Some(chunk)) => {
if let Some(chunk) = chunk {
Poll::Ready(Some(Ok(chunk)))
} else {
let framed = this.framed.as_mut().as_pin_mut().unwrap();
let force_close = !framed.codec_ref().keepalive();
release_connection(framed, force_close);
let keep_alive = framed.codec_ref().keepalive();
framed.io_mut().on_release(keep_alive);
Poll::Ready(None)
}
}
@ -322,14 +324,3 @@ impl<Io: ConnectionLifetime> Stream for PlStream<Io> {
}
}
}
fn release_connection<T, U>(framed: Pin<&mut Framed<T, U>>, force_close: bool)
where
T: ConnectionLifetime,
{
if !force_close && framed.is_read_buf_empty() && framed.is_write_buf_empty() {
framed.io_pin().release()
} else {
framed.io_pin().close()
}
}

View file

@ -17,7 +17,7 @@ use crate::message::{RequestHeadType, ResponseHead};
use crate::payload::Payload;
use super::config::ConnectorConfig;
use super::connection::{ConnectionType, IoConnection};
use super::connection::ConnectionType;
use super::error::SendRequestError;
use super::pool::Acquired;
use crate::client::connection::H2Connection;
@ -27,7 +27,7 @@ pub(crate) async fn send_request<T, B>(
head: RequestHeadType,
body: B,
created: time::Instant,
pool: Option<Acquired<T>>,
acquired: Acquired<T>,
) -> Result<(ResponseHead, Payload), SendRequestError>
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
@ -103,13 +103,13 @@ where
let res = poll_fn(|cx| io.poll_ready(cx)).await;
if let Err(e) = res {
release(io, pool, created, e.is_io());
release(io, acquired, created, e.is_io());
return Err(SendRequestError::from(e));
}
let resp = match io.send_request(req, eof) {
Ok((fut, send)) => {
release(io, pool, created, false);
release(io, acquired, created, false);
if !eof {
send_body(body, send).await?;
@ -117,7 +117,7 @@ where
fut.await.map_err(SendRequestError::from)?
}
Err(e) => {
release(io, pool, created, e.is_io());
release(io, acquired, created, e.is_io());
return Err(e.into());
}
};
@ -181,16 +181,14 @@ async fn send_body<B: MessageBody>(
/// release SendRequest object
fn release<T: AsyncRead + AsyncWrite + Unpin + 'static>(
io: H2Connection,
pool: Option<Acquired<T>>,
acquired: Acquired<T>,
created: time::Instant,
close: bool,
) {
if let Some(mut pool) = pool {
if close {
pool.close(IoConnection::new(ConnectionType::H2(io), created, None));
} else {
pool.release(IoConnection::new(ConnectionType::H2(io), created, None));
}
if close {
acquired.close(ConnectionType::H2(io));
} else {
acquired.release(ConnectionType::H2(io), created);
}
}

View file

@ -217,7 +217,7 @@ where
// construct acquired. It's used to put Io type back to pool/ close the Io type.
// permit is carried with the whole lifecycle of Acquired.
let acquired = Some(Acquired { key, inner, permit });
let acquired = Acquired { key, inner, permit };
// match the connection and spawn new one if did not get anything.
match conn {
@ -235,7 +235,7 @@ where
acquired,
))
} else {
let config = &acquired.as_ref().unwrap().inner.config;
let config = &acquired.inner.config;
let (sender, connection) = handshake(io, config).await?;
Ok(IoConnection::new(
ConnectionType::H2(H2Connection::new(sender, connection)),
@ -346,14 +346,12 @@ where
Io: AsyncRead + AsyncWrite + Unpin + 'static,
{
/// Close the IO.
pub(crate) fn close(&mut self, conn: IoConnection<Io>) {
let (conn, _) = conn.into_inner();
pub(crate) fn close(&self, conn: ConnectionType<Io>) {
self.inner.close(conn);
}
/// Release IO back into pool.
pub(crate) fn release(&mut self, conn: IoConnection<Io>) {
let (io, created) = conn.into_inner();
pub(crate) fn release(&self, conn: ConnectionType<Io>, created: Instant) {
let Acquired { key, inner, .. } = self;
inner
@ -362,12 +360,12 @@ where
.entry(key.clone())
.or_insert_with(VecDeque::new)
.push_back(PooledConnection {
conn: io,
conn,
created,
used: Instant::now(),
});
let _ = &mut self.permit;
let _ = &self.permit;
}
}
@ -447,8 +445,8 @@ mod test {
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
{
let (conn, created, mut acquired) = conn.into_parts();
acquired.release(IoConnection::new(conn, created, None));
let (conn, created, acquired) = conn.into_parts();
acquired.release(conn, created);
}
#[actix_rt::test]

View file

@ -1,5 +1,3 @@
use std::task::Poll;
use actix_service::{Service, ServiceFactory};
use futures_util::future::{ready, Ready};

View file

@ -1,5 +1,3 @@
use std::task::Poll;
use actix_codec::Framed;
use actix_service::{Service, ServiceFactory};
use futures_core::future::LocalBoxFuture;

View file

@ -1,6 +1,5 @@
use std::cell::RefCell;
use std::rc::Rc;
use std::task::Poll;
use actix_http::{Extensions, Request, Response};
use actix_router::{Path, ResourceDef, Router, Url};

View file

@ -2,7 +2,6 @@ use std::cell::RefCell;
use std::fmt;
use std::future::Future;
use std::rc::Rc;
use std::task::Poll;
use actix_http::{Error, Extensions, Response};
use actix_router::IntoPattern;

View file

@ -2,7 +2,6 @@ use std::cell::RefCell;
use std::fmt;
use std::future::Future;
use std::rc::Rc;
use std::task::Poll;
use actix_http::Extensions;
use actix_router::{ResourceDef, Router};