use std::{ io, ops::{Deref, DerefMut}, pin::Pin, task::{Context, Poll}, time, }; use actix_codec::{AsyncRead, AsyncWrite, Framed, ReadBuf}; use actix_http::{body::MessageBody, h1::ClientCodec, Payload, RequestHeadType, ResponseHead}; use actix_rt::task::JoinHandle; use bytes::Bytes; use futures_core::future::LocalBoxFuture; use h2::client::SendRequest; use super::{error::SendRequestError, h1proto, h2proto, pool::Acquired}; use crate::BoxError; /// Trait alias for types impl [tokio::io::AsyncRead] and [tokio::io::AsyncWrite]. pub trait ConnectionIo: AsyncRead + AsyncWrite + Unpin + 'static {} impl ConnectionIo for T {} /// HTTP client connection pub struct H1Connection { io: Option, created: time::Instant, acquired: Acquired, } impl H1Connection { /// close or release the connection to pool based on flag input pub(super) fn on_release(&mut self, keep_alive: bool) { if keep_alive { self.release(); } else { self.close(); } } /// Close connection fn close(&mut self) { let io = self.io.take().unwrap(); self.acquired.close(ConnectionInnerType::H1(io)); } /// Release this connection to the connection pool fn release(&mut self) { let io = self.io.take().unwrap(); self.acquired .release(ConnectionInnerType::H1(io), self.created); } fn io_pin_mut(self: Pin<&mut Self>) -> Pin<&mut Io> { Pin::new(self.get_mut().io.as_mut().unwrap()) } } impl AsyncRead for H1Connection { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { self.io_pin_mut().poll_read(cx, buf) } } impl AsyncWrite for H1Connection { fn poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { self.io_pin_mut().poll_write(cx, buf) } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.io_pin_mut().poll_flush(cx) } fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.io_pin_mut().poll_shutdown(cx) } fn poll_write_vectored( self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[io::IoSlice<'_>], ) -> Poll> { self.io_pin_mut().poll_write_vectored(cx, bufs) } fn is_write_vectored(&self) -> bool { self.io.as_ref().unwrap().is_write_vectored() } } /// HTTP2 client connection pub struct H2Connection { io: Option, created: time::Instant, acquired: Acquired, } impl Deref for H2Connection { type Target = SendRequest; fn deref(&self) -> &Self::Target { &self.io.as_ref().unwrap().sender } } impl DerefMut for H2Connection { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.io.as_mut().unwrap().sender } } impl H2Connection { /// close or release the connection to pool based on flag input pub(super) fn on_release(&mut self, close: bool) { if close { self.close(); } else { self.release(); } } /// Close connection fn close(&mut self) { let io = self.io.take().unwrap(); self.acquired.close(ConnectionInnerType::H2(io)); } /// Release this connection to the connection pool fn release(&mut self) { let io = self.io.take().unwrap(); self.acquired .release(ConnectionInnerType::H2(io), self.created); } } /// `H2ConnectionInner` has two parts: `SendRequest` and `Connection`. /// /// `Connection` is spawned as an async task on runtime and `H2ConnectionInner` holds a handle /// for this task. Therefore, it can wake up and quit the task when SendRequest is dropped. pub(super) struct H2ConnectionInner { handle: JoinHandle<()>, sender: SendRequest, } impl H2ConnectionInner { pub(super) fn new( sender: SendRequest, connection: h2::client::Connection, ) -> Self { let handle = actix_rt::spawn(async move { let _ = connection.await; }); Self { handle, sender } } } /// Cancel spawned connection task on drop. impl Drop for H2ConnectionInner { fn drop(&mut self) { // TODO: this can end up sending extraneous requests; see if there is a better way to handle if self .sender .send_request(http::Request::new(()), true) .is_err() { self.handle.abort(); } } } /// Unified connection type cover HTTP/1 Plain/TLS and HTTP/2 protocols. #[allow(dead_code)] pub enum Connection> where A: ConnectionIo, B: ConnectionIo, { Tcp(ConnectionType), Tls(ConnectionType), } /// Unified connection type cover Http1/2 protocols pub enum ConnectionType { H1(H1Connection), H2(H2Connection), } /// Helper type for storing connection types in pool. pub(super) enum ConnectionInnerType { H1(Io), H2(H2ConnectionInner), } impl ConnectionType { pub(super) fn from_pool( inner: ConnectionInnerType, created: time::Instant, acquired: Acquired, ) -> Self { match inner { ConnectionInnerType::H1(io) => Self::from_h1(io, created, acquired), ConnectionInnerType::H2(io) => Self::from_h2(io, created, acquired), } } pub(super) fn from_h1(io: Io, created: time::Instant, acquired: Acquired) -> Self { Self::H1(H1Connection { io: Some(io), created, acquired, }) } pub(super) fn from_h2( io: H2ConnectionInner, created: time::Instant, acquired: Acquired, ) -> Self { Self::H2(H2Connection { io: Some(io), created, acquired, }) } } impl Connection where A: ConnectionIo, B: ConnectionIo, { /// Send a request through connection. pub fn send_request( self, head: H, body: RB, ) -> LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>> where H: Into + 'static, RB: MessageBody + 'static, RB::Error: Into, { Box::pin(async move { match self { Connection::Tcp(ConnectionType::H1(conn)) => { h1proto::send_request(conn, head.into(), body).await } Connection::Tls(ConnectionType::H1(conn)) => { h1proto::send_request(conn, head.into(), body).await } Connection::Tls(ConnectionType::H2(conn)) => { h2proto::send_request(conn, head.into(), body).await } _ => { unreachable!("Plain TCP connection can be used only with HTTP/1.1 protocol") } } }) } /// Send request, returns Response and Framed tunnel. pub fn open_tunnel + 'static>( self, head: H, ) -> LocalBoxFuture< 'static, Result<(ResponseHead, Framed, ClientCodec>), SendRequestError>, > { Box::pin(async move { match self { Connection::Tcp(ConnectionType::H1(ref _conn)) => { let (head, framed) = h1proto::open_tunnel(self, head.into()).await?; Ok((head, framed)) } Connection::Tls(ConnectionType::H1(ref _conn)) => { let (head, framed) = h1proto::open_tunnel(self, head.into()).await?; Ok((head, framed)) } Connection::Tls(ConnectionType::H2(mut conn)) => { conn.release(); Err(SendRequestError::TunnelNotSupported) } Connection::Tcp(ConnectionType::H2(_)) => { unreachable!("Plain Tcp connection can be used only in Http1 protocol") } } }) } } impl AsyncRead for Connection where A: ConnectionIo, B: ConnectionIo, { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { match self.get_mut() { Connection::Tcp(ConnectionType::H1(conn)) => Pin::new(conn).poll_read(cx, buf), Connection::Tls(ConnectionType::H1(conn)) => Pin::new(conn).poll_read(cx, buf), _ => unreachable!("H2Connection can not impl AsyncRead trait"), } } } const H2_UNREACHABLE_WRITE: &str = "H2Connection can not impl AsyncWrite trait"; impl AsyncWrite for Connection where A: ConnectionIo, B: ConnectionIo, { fn poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { match self.get_mut() { Connection::Tcp(ConnectionType::H1(conn)) => Pin::new(conn).poll_write(cx, buf), Connection::Tls(ConnectionType::H1(conn)) => Pin::new(conn).poll_write(cx, buf), _ => unreachable!("{}", H2_UNREACHABLE_WRITE), } } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.get_mut() { Connection::Tcp(ConnectionType::H1(conn)) => Pin::new(conn).poll_flush(cx), Connection::Tls(ConnectionType::H1(conn)) => Pin::new(conn).poll_flush(cx), _ => unreachable!("{}", H2_UNREACHABLE_WRITE), } } fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.get_mut() { Connection::Tcp(ConnectionType::H1(conn)) => Pin::new(conn).poll_shutdown(cx), Connection::Tls(ConnectionType::H1(conn)) => Pin::new(conn).poll_shutdown(cx), _ => unreachable!("{}", H2_UNREACHABLE_WRITE), } } fn poll_write_vectored( self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[io::IoSlice<'_>], ) -> Poll> { match self.get_mut() { Connection::Tcp(ConnectionType::H1(conn)) => { Pin::new(conn).poll_write_vectored(cx, bufs) } Connection::Tls(ConnectionType::H1(conn)) => { Pin::new(conn).poll_write_vectored(cx, bufs) } _ => unreachable!("{}", H2_UNREACHABLE_WRITE), } } fn is_write_vectored(&self) -> bool { match *self { Connection::Tcp(ConnectionType::H1(ref conn)) => conn.is_write_vectored(), Connection::Tls(ConnectionType::H1(ref conn)) => conn.is_write_vectored(), _ => unreachable!("{}", H2_UNREACHABLE_WRITE), } } } #[cfg(test)] mod test { use std::{ future::Future, net, time::{Duration, Instant}, }; use actix_rt::{ net::TcpStream, time::{interval, Interval}, }; use super::*; #[actix_rt::test] async fn test_h2_connection_drop() { env_logger::try_init().ok(); let addr = "127.0.0.1:0".parse::().unwrap(); let listener = net::TcpListener::bind(addr).unwrap(); let local = listener.local_addr().unwrap(); std::thread::spawn(move || while listener.accept().is_ok() {}); let tcp = TcpStream::connect(local).await.unwrap(); let (sender, connection) = h2::client::handshake(tcp).await.unwrap(); let conn = H2ConnectionInner::new(sender.clone(), connection); assert!(sender.clone().ready().await.is_ok()); assert!(h2::client::SendRequest::clone(&conn.sender) .ready() .await .is_ok()); drop(conn); struct DropCheck { sender: h2::client::SendRequest, interval: Interval, start_from: Instant, } impl Future for DropCheck { type Output = (); fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.get_mut(); match futures_core::ready!(this.sender.poll_ready(cx)) { Ok(()) => { if this.start_from.elapsed() > Duration::from_secs(10) { panic!("connection should be gone and can not be ready"); } else { match this.interval.poll_tick(cx) { Poll::Ready(_) => { // prevents spurious test hang this.interval.reset(); Poll::Pending } Poll::Pending => Poll::Pending, } } } Err(_) => Poll::Ready(()), } } } DropCheck { sender, interval: interval(Duration::from_millis(100)), start_from: Instant::now(), } .await; } }