From 9e95efcc1604e88825c9fdfca19378b911b28f40 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Mon, 18 Nov 2019 18:42:27 +0600 Subject: [PATCH] migrate client to std::future --- actix-http/src/client/connection.rs | 177 ++++++------ actix-http/src/client/connector.rs | 177 +++++++----- actix-http/src/client/h1proto.rs | 283 +++++++++---------- actix-http/src/client/h2proto.rs | 254 ++++++++--------- actix-http/src/client/pool.rs | 408 +++++++++++++--------------- actix-http/src/h1/utils.rs | 4 +- actix-http/src/lib.rs | 6 +- actix-http/src/test.rs | 37 ++- actix-http/src/ws/transport.rs | 26 +- 9 files changed, 674 insertions(+), 698 deletions(-) diff --git a/actix-http/src/client/connection.rs b/actix-http/src/client/connection.rs index d2b94b3e5..70ffff6c0 100644 --- a/actix-http/src/client/connection.rs +++ b/actix-http/src/client/connection.rs @@ -1,9 +1,10 @@ +use std::pin::Pin; +use std::task::{Context, Poll}; use std::{fmt, io, time}; use actix_codec::{AsyncRead, AsyncWrite, Framed}; use bytes::{Buf, Bytes}; -use futures::future::{err, Either, Future, FutureResult}; -use futures::Poll; +use futures::future::{err, Either, Future, FutureExt, LocalBoxFuture, Ready}; use h2::client::SendRequest; use crate::body::MessageBody; @@ -22,7 +23,7 @@ pub(crate) enum ConnectionType { pub trait Connection { type Io: AsyncRead + AsyncWrite; - type Future: Future; + type Future: Future>; fn protocol(&self) -> Protocol; @@ -34,15 +35,16 @@ pub trait Connection { ) -> Self::Future; type TunnelFuture: Future< - Item = (ResponseHead, Framed), - Error = SendRequestError, + Output = Result<(ResponseHead, Framed), SendRequestError>, >; /// Send request, returns Response and Framed fn open_tunnel>(self, head: H) -> Self::TunnelFuture; } -pub(crate) trait ConnectionLifetime: AsyncRead + AsyncWrite + 'static { +pub(crate) trait ConnectionLifetime: + AsyncRead + AsyncWrite + Unpin + 'static +{ /// Close connection fn close(&mut self); @@ -91,11 +93,11 @@ impl IoConnection { impl Connection for IoConnection where - T: AsyncRead + AsyncWrite + 'static, + T: AsyncRead + AsyncWrite + Unpin + 'static, { type Io = T; type Future = - Box>; + LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>>; fn protocol(&self) -> Protocol { match self.io { @@ -111,38 +113,30 @@ where body: B, ) -> Self::Future { match self.io.take().unwrap() { - ConnectionType::H1(io) => Box::new(h1proto::send_request( - io, - head.into(), - body, - self.created, - self.pool, - )), - ConnectionType::H2(io) => Box::new(h2proto::send_request( - io, - head.into(), - body, - self.created, - self.pool, - )), + ConnectionType::H1(io) => { + h1proto::send_request(io, head.into(), body, self.created, self.pool) + .boxed_local() + } + ConnectionType::H2(io) => { + h2proto::send_request(io, head.into(), body, self.created, self.pool) + .boxed_local() + } } } type TunnelFuture = Either< - Box< - dyn Future< - Item = (ResponseHead, Framed), - Error = SendRequestError, - >, + LocalBoxFuture< + 'static, + Result<(ResponseHead, Framed), SendRequestError>, >, - FutureResult<(ResponseHead, Framed), SendRequestError>, + Ready), SendRequestError>>, >; /// Send request, returns Response and Framed fn open_tunnel>(mut self, head: H) -> Self::TunnelFuture { match self.io.take().unwrap() { ConnectionType::H1(io) => { - Either::A(Box::new(h1proto::open_tunnel(io, head.into()))) + Either::Left(h1proto::open_tunnel(io, head.into()).boxed_local()) } ConnectionType::H2(io) => { if let Some(mut pool) = self.pool.take() { @@ -152,7 +146,7 @@ where None, )); } - Either::B(err(SendRequestError::TunnelNotSupported)) + Either::Right(err(SendRequestError::TunnelNotSupported)) } } } @@ -166,12 +160,12 @@ pub(crate) enum EitherConnection { impl Connection for EitherConnection where - A: AsyncRead + AsyncWrite + 'static, - B: AsyncRead + AsyncWrite + 'static, + A: AsyncRead + AsyncWrite + Unpin + 'static, + B: AsyncRead + AsyncWrite + Unpin + 'static, { type Io = EitherIo; type Future = - Box>; + LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>>; fn protocol(&self) -> Protocol { match self { @@ -191,24 +185,22 @@ where } } - type TunnelFuture = Box< - dyn Future< - Item = (ResponseHead, Framed), - Error = SendRequestError, - >, + type TunnelFuture = LocalBoxFuture< + 'static, + Result<(ResponseHead, Framed), SendRequestError>, >; /// Send request, returns Response and Framed fn open_tunnel>(self, head: H) -> Self::TunnelFuture { match self { - EitherConnection::A(con) => Box::new( - con.open_tunnel(head) - .map(|(head, framed)| (head, framed.map_io(EitherIo::A))), - ), - EitherConnection::B(con) => Box::new( - con.open_tunnel(head) - .map(|(head, framed)| (head, framed.map_io(EitherIo::B))), - ), + EitherConnection::A(con) => con + .open_tunnel(head) + .map(|res| res.map(|(head, framed)| (head, framed.map_io(EitherIo::A)))) + .boxed_local(), + EitherConnection::B(con) => con + .open_tunnel(head) + .map(|res| res.map(|(head, framed)| (head, framed.map_io(EitherIo::B)))) + .boxed_local(), } } } @@ -218,24 +210,22 @@ pub enum EitherIo { B(B), } -impl io::Read for EitherIo -where - A: io::Read, - B: io::Read, -{ - fn read(&mut self, buf: &mut [u8]) -> io::Result { - match self { - EitherIo::A(ref mut val) => val.read(buf), - EitherIo::B(ref mut val) => val.read(buf), - } - } -} - impl AsyncRead for EitherIo where - A: AsyncRead, - B: AsyncRead, + A: AsyncRead + Unpin, + B: AsyncRead + Unpin, { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + match self.get_mut() { + EitherIo::A(ref mut val) => Pin::new(val).poll_read(cx, buf), + EitherIo::B(ref mut val) => Pin::new(val).poll_read(cx, buf), + } + } + unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { match self { EitherIo::A(ref val) => val.prepare_uninitialized_buffer(buf), @@ -244,45 +234,50 @@ where } } -impl io::Write for EitherIo -where - A: io::Write, - B: io::Write, -{ - fn write(&mut self, buf: &[u8]) -> io::Result { - match self { - EitherIo::A(ref mut val) => val.write(buf), - EitherIo::B(ref mut val) => val.write(buf), - } - } - - fn flush(&mut self) -> io::Result<()> { - match self { - EitherIo::A(ref mut val) => val.flush(), - EitherIo::B(ref mut val) => val.flush(), - } - } -} - impl AsyncWrite for EitherIo where - A: AsyncWrite, - B: AsyncWrite, + A: AsyncWrite + Unpin, + B: AsyncWrite + Unpin, { - fn shutdown(&mut self) -> Poll<(), io::Error> { - match self { - EitherIo::A(ref mut val) => val.shutdown(), - EitherIo::B(ref mut val) => val.shutdown(), + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context, + buf: &[u8], + ) -> Poll> { + match self.get_mut() { + EitherIo::A(ref mut val) => Pin::new(val).poll_write(cx, buf), + EitherIo::B(ref mut val) => Pin::new(val).poll_write(cx, buf), } } - fn write_buf(&mut self, buf: &mut U) -> Poll + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + EitherIo::A(ref mut val) => Pin::new(val).poll_flush(cx), + EitherIo::B(ref mut val) => Pin::new(val).poll_flush(cx), + } + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + match self.get_mut() { + EitherIo::A(ref mut val) => Pin::new(val).poll_shutdown(cx), + EitherIo::B(ref mut val) => Pin::new(val).poll_shutdown(cx), + } + } + + fn poll_write_buf( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut U, + ) -> Poll> where Self: Sized, { - match self { - EitherIo::A(ref mut val) => val.write_buf(buf), - EitherIo::B(ref mut val) => val.write_buf(buf), + match self.get_mut() { + EitherIo::A(ref mut val) => Pin::new(val).poll_write_buf(cx, buf), + EitherIo::B(ref mut val) => Pin::new(val).poll_write_buf(cx, buf), } } } diff --git a/actix-http/src/client/connector.rs b/actix-http/src/client/connector.rs index 4ae28ba68..7421cb02e 100644 --- a/actix-http/src/client/connector.rs +++ b/actix-http/src/client/connector.rs @@ -11,6 +11,7 @@ use actix_connect::{ }; use actix_service::{apply_fn, Service}; use actix_utils::timeout::{TimeoutError, TimeoutService}; +use futures::future::Ready; use http::Uri; use tokio_net::tcp::TcpStream; @@ -116,12 +117,13 @@ impl Connector { /// Use custom connector. pub fn connector(self, connector: T1) -> Connector where - U1: AsyncRead + AsyncWrite + fmt::Debug, + U1: AsyncRead + AsyncWrite + Unpin + fmt::Debug, T1: Service< Request = TcpConnect, Response = TcpConnection, Error = actix_connect::ConnectError, > + Clone, + T1::Future: Unpin, { Connector { connector, @@ -138,13 +140,12 @@ impl Connector { impl Connector where - U: AsyncRead + AsyncWrite + fmt::Debug + 'static, + U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static, T: Service< Request = TcpConnect, Response = TcpConnection, Error = actix_connect::ConnectError, - > + Clone - + 'static, + > + 'static, { /// Connection timeout, i.e. max time to connect to remote host including dns name resolution. /// Set to 1 second by default. @@ -220,7 +221,7 @@ where { let connector = TimeoutService::new( self.timeout, - apply_fn(self.connector, |msg: Connect, srv| { + apply_fn(UnpinWrapper(self.connector), |msg: Connect, srv| { srv.call(TcpConnect::new(msg.uri).set_addr(msg.addr)) }) .map_err(ConnectError::from) @@ -337,10 +338,48 @@ where } } +#[derive(Clone)] +struct UnpinWrapper(T); + +impl Unpin for UnpinWrapper {} + +impl Service for UnpinWrapper { + type Request = T::Request; + type Response = T::Response; + type Error = T::Error; + type Future = UnpinWrapperFut; + + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + self.0.poll_ready(cx) + } + + fn call(&mut self, req: T::Request) -> Self::Future { + UnpinWrapperFut { + fut: self.0.call(req), + } + } +} + +struct UnpinWrapperFut { + fut: T::Future, +} + +impl Unpin for UnpinWrapperFut {} + +impl Future for UnpinWrapperFut { + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + unsafe { Pin::new_unchecked(&mut self.get_mut().fut) }.poll(cx) + } +} + #[cfg(not(any(feature = "ssl", feature = "rust-tls")))] mod connect_impl { - use futures::future::{err, Either, FutureResult}; - use futures::Poll; + use std::task::{Context, Poll}; + + use futures::future::{err, Either, Ready}; + use futures::ready; use super::*; use crate::client::connection::IoConnection; @@ -349,7 +388,7 @@ mod connect_impl { where Io: AsyncRead + AsyncWrite + 'static, T: Service - + Clone + + Unpin + 'static, { pub(crate) tcp_pool: ConnectionPool, @@ -359,7 +398,7 @@ mod connect_impl { where Io: AsyncRead + AsyncWrite + 'static, T: Service - + Clone + + Unpin + 'static, { fn clone(&self) -> Self { @@ -371,29 +410,30 @@ mod connect_impl { impl Service for InnerConnector where - Io: AsyncRead + AsyncWrite + 'static, + Io: AsyncRead + AsyncWrite + Unpin + 'static, T: Service - + Clone + + Unpin + 'static, + T::Future: Unpin, { type Request = Connect; type Response = IoConnection; type Error = ConnectError; type Future = Either< as Service>::Future, - FutureResult, ConnectError>, + Ready, ConnectError>>, >; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - self.tcp_pool.poll_ready() + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + self.tcp_pool.poll_ready(cx) } fn call(&mut self, req: Connect) -> Self::Future { match req.uri.scheme_str() { Some("https") | Some("wss") => { - Either::B(err(ConnectError::SslIsNotSupported)) + Either::Right(err(ConnectError::SslIsNotSupported)) } - _ => Either::A(self.tcp_pool.call(req)), + _ => Either::Left(self.tcp_pool.call(req)), } } } @@ -403,18 +443,20 @@ mod connect_impl { mod connect_impl { use std::marker::PhantomData; - use futures::future::{Either, FutureResult}; - use futures::{Async, Future, Poll}; + use futures::future::Either; + use futures::ready; use super::*; use crate::client::connection::EitherConnection; pub(crate) struct InnerConnector where - Io1: AsyncRead + AsyncWrite + 'static, - Io2: AsyncRead + AsyncWrite + 'static, + Io1: AsyncRead + AsyncWrite + Unpin + 'static, + Io2: AsyncRead + AsyncWrite + Unpin + 'static, T1: Service, T2: Service, + T1::Future: Unpin, + T2::Future: Unpin, { pub(crate) tcp_pool: ConnectionPool, pub(crate) ssl_pool: ConnectionPool, @@ -422,14 +464,16 @@ mod connect_impl { impl Clone for InnerConnector where - Io1: AsyncRead + AsyncWrite + 'static, - Io2: AsyncRead + AsyncWrite + 'static, + Io1: AsyncRead + AsyncWrite + Unpin + 'static, + Io2: AsyncRead + AsyncWrite + Unpin + 'static, T1: Service - + Clone + + Unpin + 'static, T2: Service - + Clone + + Unpin + 'static, + T1::Future: Unpin, + T2::Future: Unpin, { fn clone(&self) -> Self { InnerConnector { @@ -441,52 +485,50 @@ mod connect_impl { impl Service for InnerConnector where - Io1: AsyncRead + AsyncWrite + 'static, - Io2: AsyncRead + AsyncWrite + 'static, + Io1: AsyncRead + AsyncWrite + Unpin + 'static, + Io2: AsyncRead + AsyncWrite + Unpin + 'static, T1: Service - + Clone + + Unpin + 'static, T2: Service - + Clone + + Unpin + 'static, + T1::Future: Unpin, + T2::Future: Unpin, { type Request = Connect; type Response = EitherConnection; type Error = ConnectError; type Future = Either< - FutureResult, - Either< - InnerConnectorResponseA, - InnerConnectorResponseB, - >, + InnerConnectorResponseA, + InnerConnectorResponseB, >; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - self.tcp_pool.poll_ready() + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + self.tcp_pool.poll_ready(cx) } fn call(&mut self, req: Connect) -> Self::Future { match req.uri.scheme_str() { - Some("https") | Some("wss") => { - Either::B(Either::B(InnerConnectorResponseB { - fut: self.ssl_pool.call(req), - _t: PhantomData, - })) - } - _ => Either::B(Either::A(InnerConnectorResponseA { + Some("https") | Some("wss") => Either::B(InnerConnectorResponseB { + fut: self.ssl_pool.call(req), + _t: PhantomData, + }), + _ => Either::A(InnerConnectorResponseA { fut: self.tcp_pool.call(req), _t: PhantomData, - })), + }), } } } pub(crate) struct InnerConnectorResponseA where - Io1: AsyncRead + AsyncWrite + 'static, + Io1: AsyncRead + AsyncWrite + Unpin + 'static, T: Service - + Clone + + Unpin + 'static, + T::Future: Unpin, { fut: as Service>::Future, _t: PhantomData, @@ -495,28 +537,29 @@ mod connect_impl { impl Future for InnerConnectorResponseA where T: Service - + Clone + + Unpin + 'static, - Io1: AsyncRead + AsyncWrite + 'static, - Io2: AsyncRead + AsyncWrite + 'static, + T::Future: Unpin, + Io1: AsyncRead + AsyncWrite + Unpin + 'static, + Io2: AsyncRead + AsyncWrite + Unpin + 'static, { - type Item = EitherConnection; - type Error = ConnectError; + type Output = Result, ConnectError>; - fn poll(&mut self) -> Poll { - match self.fut.poll()? { - Async::NotReady => Ok(Async::NotReady), - Async::Ready(res) => Ok(Async::Ready(EitherConnection::A(res))), - } + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + Poll::Ready( + ready!(Pin::new(&mut self.get_mut().fut).poll(cx)) + .map(|res| EitherConnection::A(res)), + ) } } pub(crate) struct InnerConnectorResponseB where - Io2: AsyncRead + AsyncWrite + 'static, + Io2: AsyncRead + AsyncWrite + Unpin + 'static, T: Service - + Clone + + Unpin + 'static, + T::Future: Unpin, { fut: as Service>::Future, _t: PhantomData, @@ -525,19 +568,19 @@ mod connect_impl { impl Future for InnerConnectorResponseB where T: Service - + Clone + + Unpin + 'static, - Io1: AsyncRead + AsyncWrite + 'static, - Io2: AsyncRead + AsyncWrite + 'static, + T::Future: Unpin, + Io1: AsyncRead + AsyncWrite + Unpin + 'static, + Io2: AsyncRead + AsyncWrite + Unpin + 'static, { - type Item = EitherConnection; - type Error = ConnectError; + type Output = Result, ConnectError>; - fn poll(&mut self) -> Poll { - match self.fut.poll()? { - Async::NotReady => Ok(Async::NotReady), - Async::Ready(res) => Ok(Async::Ready(EitherConnection::B(res))), - } + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + Poll::Ready( + ready!(Pin::new(&mut self.get_mut().fut).poll(cx)) + .map(|res| EitherConnection::B(res)), + ) } } } diff --git a/actix-http/src/client/h1proto.rs b/actix-http/src/client/h1proto.rs index 14984253b..f8902a0ef 100644 --- a/actix-http/src/client/h1proto.rs +++ b/actix-http/src/client/h1proto.rs @@ -6,8 +6,8 @@ use std::{io, time}; use actix_codec::{AsyncRead, AsyncWrite, Framed}; use bytes::{BufMut, Bytes, BytesMut}; -use futures::future::{ok, Either}; -use futures::{Sink, Stream}; +use futures::future::{ok, poll_fn, Either}; +use futures::{Sink, SinkExt, Stream, StreamExt}; use crate::error::PayloadError; use crate::h1; @@ -21,15 +21,15 @@ use super::error::{ConnectError, SendRequestError}; use super::pool::Acquired; use crate::body::{BodySize, MessageBody}; -pub(crate) fn send_request( +pub(crate) async fn send_request( io: T, mut head: RequestHeadType, body: B, created: time::Instant, pool: Option>, -) -> impl Future +) -> Result<(ResponseHead, Payload), SendRequestError> where - T: AsyncRead + AsyncWrite + 'static, + T: AsyncRead + AsyncWrite + Unpin + 'static, B: MessageBody, { // set request host header @@ -65,68 +65,98 @@ where io: Some(io), }; - let len = body.size(); - // create Framed and send request - Framed::new(io, h1::ClientCodec::default()) - .send((head, len).into()) - .from_err() - // send request body - .and_then(move |framed| match body.size() { - BodySize::None | BodySize::Empty | BodySize::Sized(0) => { - Either::A(ok(framed)) - } - _ => Either::B(SendBody::new(body, framed)), - }) - // read response and init read body - .and_then(|framed| { - framed - .into_future() - .map_err(|(e, _)| SendRequestError::from(e)) - .and_then(|(item, framed)| { - if let Some(res) = item { - match framed.get_codec().message_type() { - h1::MessageType::None => { - let force_close = !framed.get_codec().keepalive(); - release_connection(framed, force_close); - Ok((res, Payload::None)) - } - _ => { - let pl: PayloadStream = Box::new(PlStream::new(framed)); - Ok((res, pl.into())) - } - } - } else { - Err(ConnectError::Disconnected.into()) - } - }) - }) + let mut framed = Framed::new(io, h1::ClientCodec::default()); + framed.send((head, body.size()).into()).await?; + + // send request body + match body.size() { + BodySize::None | BodySize::Empty | BodySize::Sized(0) => (), + _ => send_body(body, &mut framed).await?, + }; + + // read response and init read body + let (head, framed) = if let (Some(result), framed) = framed.into_future().await { + let item = result.map_err(SendRequestError::from)?; + (item, framed) + } else { + return Err(SendRequestError::from(ConnectError::Disconnected)); + }; + + match framed.get_codec().message_type() { + h1::MessageType::None => { + let force_close = !framed.get_codec().keepalive(); + release_connection(framed, force_close); + Ok((head, Payload::None)) + } + _ => { + let pl: PayloadStream = PlStream::new(framed).boxed_local(); + Ok((head, pl.into())) + } + } } -pub(crate) fn open_tunnel( +pub(crate) async fn open_tunnel( io: T, head: RequestHeadType, -) -> impl Future), Error = SendRequestError> +) -> Result<(ResponseHead, Framed), SendRequestError> where - T: AsyncRead + AsyncWrite + 'static, + T: AsyncRead + AsyncWrite + Unpin + 'static, { // create Framed and send request - Framed::new(io, h1::ClientCodec::default()) - .send((head, BodySize::None).into()) - .from_err() - // read response - .and_then(|framed| { - framed - .into_future() - .map_err(|(e, _)| SendRequestError::from(e)) - .and_then(|(head, framed)| { - if let Some(head) = head { - Ok((head, framed)) + let mut framed = Framed::new(io, h1::ClientCodec::default()); + framed.send((head, BodySize::None).into()).await?; + + // read response + if let (Some(result), framed) = framed.into_future().await { + let head = result.map_err(SendRequestError::from)?; + Ok((head, framed)) + } else { + Err(SendRequestError::from(ConnectError::Disconnected)) + } +} + +/// send request body to the peer +pub(crate) async fn send_body( + mut body: B, + framed: &mut Framed, +) -> Result<(), SendRequestError> +where + I: ConnectionLifetime, + B: MessageBody, +{ + let mut eof = false; + while !eof { + while !eof && !framed.is_write_buf_full() { + match poll_fn(|cx| body.poll_next(cx)).await { + Some(result) => { + framed.write(h1::Message::Chunk(Some(result?)))?; + } + None => { + eof = true; + framed.write(h1::Message::Chunk(None))?; + } + } + } + + if !framed.is_write_buf_empty() { + poll_fn(|cx| match framed.flush(cx) { + Poll::Ready(Ok(_)) => Poll::Ready(Ok(())), + Poll::Ready(Err(err)) => Poll::Ready(Err(err)), + Poll::Pending => { + if !framed.is_write_buf_full() { + Poll::Ready(Ok(())) } else { - Err(SendRequestError::from(ConnectError::Disconnected)) + Poll::Pending } - }) - }) + } + }) + .await?; + } + } + + SinkExt::flush(framed).await?; + Ok(()) } #[doc(hidden)] @@ -137,7 +167,10 @@ pub struct H1Connection { pool: Option>, } -impl ConnectionLifetime for H1Connection { +impl ConnectionLifetime for H1Connection +where + T: AsyncRead + AsyncWrite + Unpin + 'static, +{ /// Close connection fn close(&mut self) { if let Some(mut pool) = self.pool.take() { @@ -165,98 +198,41 @@ impl ConnectionLifetime for H1Connection } } -impl io::Read for H1Connection { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.io.as_mut().unwrap().read(buf) +impl AsyncRead for H1Connection { + unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { + self.io.as_ref().unwrap().prepare_uninitialized_buffer(buf) + } + + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + Pin::new(&mut self.io.as_mut().unwrap()).poll_read(cx, buf) } } -impl AsyncRead for H1Connection {} - -impl io::Write for H1Connection { - fn write(&mut self, buf: &[u8]) -> io::Result { - self.io.as_mut().unwrap().write(buf) +impl AsyncWrite for H1Connection { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.io.as_mut().unwrap()).poll_write(cx, buf) } - fn flush(&mut self) -> io::Result<()> { - self.io.as_mut().unwrap().flush() + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(self.io.as_mut().unwrap()).poll_flush(cx) } -} -impl AsyncWrite for H1Connection { - fn shutdown(&mut self) -> Poll<(), io::Error> { - self.io.as_mut().unwrap().shutdown() - } -} - -/// Future responsible for sending request body to the peer -pub(crate) struct SendBody { - body: Option, - framed: Option>, - flushed: bool, -} - -impl SendBody -where - I: AsyncRead + AsyncWrite + 'static, - B: MessageBody, -{ - pub(crate) fn new(body: B, framed: Framed) -> Self { - SendBody { - body: Some(body), - framed: Some(framed), - flushed: true, - } - } -} - -impl Future for SendBody -where - I: ConnectionLifetime, - B: MessageBody, -{ - type Item = Framed; - type Error = SendRequestError; - - fn poll(&mut self) -> Poll { - let mut body_ready = true; - loop { - while body_ready - && self.body.is_some() - && !self.framed.as_ref().unwrap().is_write_buf_full() - { - match self.body.as_mut().unwrap().poll_next()? { - Async::Ready(item) => { - // check if body is done - if item.is_none() { - let _ = self.body.take(); - } - self.flushed = false; - self.framed - .as_mut() - .unwrap() - .force_send(h1::Message::Chunk(item))?; - break; - } - Async::NotReady => body_ready = false, - } - } - - if !self.flushed { - match self.framed.as_mut().unwrap().poll_complete()? { - Async::Ready(_) => { - self.flushed = true; - continue; - } - Async::NotReady => return Ok(Async::NotReady), - } - } - - if self.body.is_none() { - return Ok(Async::Ready(self.framed.take().unwrap())); - } - return Ok(Async::NotReady); - } + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context, + ) -> Poll> { + Pin::new(self.io.as_mut().unwrap()).poll_shutdown(cx) } } @@ -273,23 +249,24 @@ impl PlStream { } impl Stream for PlStream { - type Item = Bytes; - type Error = PayloadError; + type Item = Result; - fn poll(&mut self) -> Poll, Self::Error> { - match self.framed.as_mut().unwrap().poll()? { - Async::NotReady => Ok(Async::NotReady), - Async::Ready(Some(chunk)) => { + fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + let this = self.get_mut(); + + match this.framed.as_mut().unwrap().next_item(cx)? { + Poll::Pending => Poll::Pending, + Poll::Ready(Some(chunk)) => { if let Some(chunk) = chunk { - Ok(Async::Ready(Some(chunk))) + Poll::Ready(Some(Ok(chunk))) } else { - let framed = self.framed.take().unwrap(); + let framed = this.framed.take().unwrap(); let force_close = !framed.get_codec().keepalive(); release_connection(framed, force_close); - Ok(Async::Ready(None)) + Poll::Ready(None) } } - Async::Ready(None) => Ok(Async::Ready(None)), + Poll::Ready(None) => Poll::Ready(None), } } } diff --git a/actix-http/src/client/h2proto.rs b/actix-http/src/client/h2proto.rs index 50d74fe1b..25299fd61 100644 --- a/actix-http/src/client/h2proto.rs +++ b/actix-http/src/client/h2proto.rs @@ -5,7 +5,7 @@ use std::time; use actix_codec::{AsyncRead, AsyncWrite}; use bytes::Bytes; -use futures::future::{err, Either}; +use futures::future::{err, poll_fn, Either}; use h2::{client::SendRequest, SendStream}; use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, TRANSFER_ENCODING}; use http::{request::Request, HttpTryFrom, Method, Version}; @@ -19,15 +19,15 @@ use super::connection::{ConnectionType, IoConnection}; use super::error::SendRequestError; use super::pool::Acquired; -pub(crate) fn send_request( - io: SendRequest, +pub(crate) async fn send_request( + mut io: SendRequest, head: RequestHeadType, body: B, created: time::Instant, pool: Option>, -) -> impl Future +) -> Result<(ResponseHead, Payload), SendRequestError> where - T: AsyncRead + AsyncWrite + 'static, + T: AsyncRead + AsyncWrite + Unpin + 'static, B: MessageBody, { trace!("Sending client request: {:?} {:?}", head, body.size()); @@ -38,158 +38,138 @@ where _ => false, }; - io.ready() - .map_err(SendRequestError::from) - .and_then(move |mut io| { - let mut req = Request::new(()); - *req.uri_mut() = head.as_ref().uri.clone(); - *req.method_mut() = head.as_ref().method.clone(); - *req.version_mut() = Version::HTTP_2; + let mut req = Request::new(()); + *req.uri_mut() = head.as_ref().uri.clone(); + *req.method_mut() = head.as_ref().method.clone(); + *req.version_mut() = Version::HTTP_2; - let mut skip_len = true; - // let mut has_date = false; + let mut skip_len = true; + // let mut has_date = false; - // Content length - let _ = match length { - BodySize::None => None, - BodySize::Stream => { - skip_len = false; - None - } - BodySize::Empty => req - .headers_mut() - .insert(CONTENT_LENGTH, HeaderValue::from_static("0")), - BodySize::Sized(len) => req.headers_mut().insert( - CONTENT_LENGTH, - HeaderValue::try_from(format!("{}", len)).unwrap(), - ), - BodySize::Sized64(len) => req.headers_mut().insert( - CONTENT_LENGTH, - HeaderValue::try_from(format!("{}", len)).unwrap(), - ), - }; + // Content length + let _ = match length { + BodySize::None => None, + BodySize::Stream => { + skip_len = false; + None + } + BodySize::Empty => req + .headers_mut() + .insert(CONTENT_LENGTH, HeaderValue::from_static("0")), + BodySize::Sized(len) => req.headers_mut().insert( + CONTENT_LENGTH, + HeaderValue::try_from(format!("{}", len)).unwrap(), + ), + BodySize::Sized64(len) => req.headers_mut().insert( + CONTENT_LENGTH, + HeaderValue::try_from(format!("{}", len)).unwrap(), + ), + }; - // Extracting extra headers from RequestHeadType. HeaderMap::new() does not allocate. - let (head, extra_headers) = match head { - RequestHeadType::Owned(head) => { - (RequestHeadType::Owned(head), HeaderMap::new()) - } - RequestHeadType::Rc(head, extra_headers) => ( - RequestHeadType::Rc(head, None), - extra_headers.unwrap_or_else(HeaderMap::new), - ), - }; + // Extracting extra headers from RequestHeadType. HeaderMap::new() does not allocate. + let (head, extra_headers) = match head { + RequestHeadType::Owned(head) => (RequestHeadType::Owned(head), HeaderMap::new()), + RequestHeadType::Rc(head, extra_headers) => ( + RequestHeadType::Rc(head, None), + extra_headers.unwrap_or_else(HeaderMap::new), + ), + }; - // merging headers from head and extra headers. - let headers = head - .as_ref() - .headers - .iter() - .filter(|(name, _)| !extra_headers.contains_key(*name)) - .chain(extra_headers.iter()); + // merging headers from head and extra headers. + let headers = head + .as_ref() + .headers + .iter() + .filter(|(name, _)| !extra_headers.contains_key(*name)) + .chain(extra_headers.iter()); - // copy headers - for (key, value) in headers { - match *key { - CONNECTION | TRANSFER_ENCODING => continue, // http2 specific - CONTENT_LENGTH if skip_len => continue, - // DATE => has_date = true, - _ => (), - } - req.headers_mut().append(key, value.clone()); + // copy headers + for (key, value) in headers { + match *key { + CONNECTION | TRANSFER_ENCODING => continue, // http2 specific + CONTENT_LENGTH if skip_len => continue, + // DATE => has_date = true, + _ => (), + } + req.headers_mut().append(key, value.clone()); + } + + let res = poll_fn(|cx| io.poll_ready(cx)).await; + if let Err(e) = res { + release(io, pool, created, e.is_io()); + return Err(SendRequestError::from(e)); + } + + let resp = match io.send_request(req, eof) { + Ok((fut, send)) => { + release(io, pool, created, false); + + if !eof { + send_body(body, send).await?; } + fut.await.map_err(SendRequestError::from)? + } + Err(e) => { + release(io, pool, created, e.is_io()); + return Err(e.into()); + } + }; - match io.send_request(req, eof) { - Ok((res, send)) => { - release(io, pool, created, false); + let (parts, body) = resp.into_parts(); + let payload = if head_req { Payload::None } else { body.into() }; - if !eof { - Either::A(Either::B( - SendBody { - body, - send, - buf: None, - } - .and_then(move |_| res.map_err(SendRequestError::from)), - )) - } else { - Either::B(res.map_err(SendRequestError::from)) - } - } - Err(e) => { - release(io, pool, created, e.is_io()); - Either::A(Either::A(err(e.into()))) - } - } - }) - .and_then(move |resp| { - let (parts, body) = resp.into_parts(); - let payload = if head_req { Payload::None } else { body.into() }; - - let mut head = ResponseHead::new(parts.status); - head.version = parts.version; - head.headers = parts.headers.into(); - Ok((head, payload)) - }) - .from_err() + let mut head = ResponseHead::new(parts.status); + head.version = parts.version; + head.headers = parts.headers.into(); + Ok((head, payload)) } -struct SendBody { - body: B, - send: SendStream, - buf: Option, -} - -impl Future for SendBody { - type Item = (); - type Error = SendRequestError; - - fn poll(&mut self) -> Poll { - loop { - if self.buf.is_none() { - match self.body.poll_next() { - Ok(Async::Ready(Some(buf))) => { - self.send.reserve_capacity(buf.len()); - self.buf = Some(buf); - } - Ok(Async::Ready(None)) => { - if let Err(e) = self.send.send_data(Bytes::new(), true) { - return Err(e.into()); - } - self.send.reserve_capacity(0); - return Ok(Async::Ready(())); - } - Ok(Async::NotReady) => return Ok(Async::NotReady), - Err(e) => return Err(e.into()), +async fn send_body( + mut body: B, + mut send: SendStream, +) -> Result<(), SendRequestError> { + let mut buf = None; + loop { + if buf.is_none() { + match poll_fn(|cx| body.poll_next(cx)).await { + Some(Ok(b)) => { + send.reserve_capacity(b.len()); + buf = Some(b); } - } - - match self.send.poll_capacity() { - Ok(Async::NotReady) => return Ok(Async::NotReady), - Ok(Async::Ready(None)) => return Ok(Async::Ready(())), - Ok(Async::Ready(Some(cap))) => { - let mut buf = self.buf.take().unwrap(); - let len = buf.len(); - let bytes = buf.split_to(std::cmp::min(cap, len)); - - if let Err(e) = self.send.send_data(bytes, false) { + Some(Err(e)) => return Err(e.into()), + None => { + if let Err(e) = send.send_data(Bytes::new(), true) { return Err(e.into()); - } else { - if !buf.is_empty() { - self.send.reserve_capacity(buf.len()); - self.buf = Some(buf); - } - continue; } + send.reserve_capacity(0); + return Ok(()); } - Err(e) => return Err(e.into()), } } + + match poll_fn(|cx| send.poll_capacity(cx)).await { + None => return Ok(()), + Some(Ok(cap)) => { + let b = buf.as_mut().unwrap(); + let len = b.len(); + let bytes = b.split_to(std::cmp::min(cap, len)); + + if let Err(e) = send.send_data(bytes, false) { + return Err(e.into()); + } else { + if !b.is_empty() { + send.reserve_capacity(b.len()); + } + continue; + } + } + Some(Err(e)) => return Err(e.into()), + } } } // release SendRequest object -fn release( +fn release( io: SendRequest, pool: Option>, created: time::Instant, diff --git a/actix-http/src/client/pool.rs b/actix-http/src/client/pool.rs index 4d02e0a13..ee4c4ab9f 100644 --- a/actix-http/src/client/pool.rs +++ b/actix-http/src/client/pool.rs @@ -9,11 +9,10 @@ use std::time::{Duration, Instant}; use actix_codec::{AsyncRead, AsyncWrite}; use actix_service::Service; -use actix_utils::oneshot; -use actix_utils::task::LocalWaker; +use actix_utils::{oneshot, task::LocalWaker}; use bytes::Bytes; -use futures::future::{err, ok, Either, FutureResult}; -use h2::client::{handshake, Handshake}; +use futures::future::{err, ok, poll_fn, Either, FutureExt, LocalBoxFuture, Ready}; +use h2::client::{handshake, Connection, SendRequest}; use hashbrown::HashMap; use http::uri::Authority; use indexmap::IndexSet; @@ -43,17 +42,15 @@ impl From for Key { } /// Connections pool -pub(crate) struct ConnectionPool( - T, - Rc>>, -); +pub(crate) struct ConnectionPool(Rc>, Rc>>); impl ConnectionPool where Io: AsyncRead + AsyncWrite + 'static, T: Service - + Clone + + Unpin + 'static, + T::Future: Unpin, { pub(crate) fn new( connector: T, @@ -63,7 +60,7 @@ where limit: usize, ) -> Self { ConnectionPool( - connector, + Rc::new(RefCell::new(connector)), Rc::new(RefCell::new(Inner { conn_lifetime, conn_keep_alive, @@ -73,7 +70,7 @@ where waiters: Slab::new(), waiters_queue: IndexSet::new(), available: HashMap::new(), - task: None, + waker: LocalWaker::new(), })), ) } @@ -81,8 +78,7 @@ where impl Clone for ConnectionPool where - T: Clone, - Io: AsyncRead + AsyncWrite + 'static, + Io: 'static, { fn clone(&self) -> Self { ConnectionPool(self.0.clone(), self.1.clone()) @@ -91,86 +87,118 @@ where impl Service for ConnectionPool where - Io: AsyncRead + AsyncWrite + 'static, + Io: AsyncRead + AsyncWrite + Unpin + 'static, T: Service - + Clone + + Unpin + 'static, + T::Future: Unpin, { type Request = Connect; type Response = IoConnection; type Error = ConnectError; - type Future = Either< - FutureResult, - Either, OpenConnection>, - >; + type Future = LocalBoxFuture<'static, Result, ConnectError>>; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - self.0.poll_ready() + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + self.0.poll_ready(cx) } fn call(&mut self, req: Connect) -> Self::Future { - let key = if let Some(authority) = req.uri.authority_part() { - authority.clone().into() - } else { - return Either::A(err(ConnectError::Unresolverd)); + // start support future + tokio_executor::current_thread::spawn(ConnectorPoolSupport { + connector: self.0.clone(), + inner: self.1.clone(), + }); + + let mut connector = self.0.clone(); + let inner = self.1.clone(); + + let fut = async move { + let key = if let Some(authority) = req.uri.authority_part() { + authority.clone().into() + } else { + return Err(ConnectError::Unresolverd); + }; + + // acquire connection + match poll_fn(|cx| Poll::Ready(inner.borrow_mut().acquire(&key, cx))).await { + Acquire::Acquired(io, created) => { + // use existing connection + return Ok(IoConnection::new( + io, + created, + Some(Acquired(key, Some(inner))), + )); + } + Acquire::Available => { + // open tcp connection + let (io, proto) = connector.call(req).await?; + + let guard = OpenGuard::new(key, inner); + + if proto == Protocol::Http1 { + Ok(IoConnection::new( + ConnectionType::H1(io), + Instant::now(), + Some(guard.consume()), + )) + } else { + let (snd, connection) = handshake(io).await?; + tokio_executor::current_thread::spawn(connection.map(|_| ())); + Ok(IoConnection::new( + ConnectionType::H2(snd), + Instant::now(), + Some(guard.consume()), + )) + } + } + _ => { + // connection is not available, wait + let (rx, token) = inner.borrow_mut().wait_for(req); + + let guard = WaiterGuard::new(key, token, inner); + let res = match rx.await { + Err(_) => Err(ConnectError::Disconnected), + Ok(res) => res, + }; + guard.consume(); + res + } + } }; - // acquire connection - match self.1.as_ref().borrow_mut().acquire(&key) { - Acquire::Acquired(io, created) => { - // use existing connection - return Either::A(ok(IoConnection::new( - io, - created, - Some(Acquired(key, Some(self.1.clone()))), - ))); - } - Acquire::Available => { - // open new connection - return Either::B(Either::B(OpenConnection::new( - key, - self.1.clone(), - self.0.call(req), - ))); - } - _ => (), - } - - // connection is not available, wait - let (rx, token, support) = self.1.as_ref().borrow_mut().wait_for(req); - - // start support future - if !support { - self.1.as_ref().borrow_mut().task = Some(AtomicTask::new()); - tokio_executor::current_thread::spawn(ConnectorPoolSupport { - connector: self.0.clone(), - inner: self.1.clone(), - }) - } - - Either::B(Either::A(WaitForConnection { - rx, - key, - token, - inner: Some(self.1.clone()), - })) + fut.boxed_local() } } -#[doc(hidden)] -pub struct WaitForConnection +struct WaiterGuard where - Io: AsyncRead + AsyncWrite + 'static, + Io: AsyncRead + AsyncWrite + Unpin + 'static, { key: Key, token: usize, - rx: oneshot::Receiver, ConnectError>>, inner: Option>>>, } -impl Drop for WaitForConnection +impl WaiterGuard where - Io: AsyncRead + AsyncWrite + 'static, + Io: AsyncRead + AsyncWrite + Unpin + 'static, +{ + fn new(key: Key, token: usize, inner: Rc>>) -> Self { + Self { + key, + token, + inner: Some(inner), + } + } + + fn consume(mut self) { + let _ = self.inner.take(); + } +} + +impl Drop for WaiterGuard +where + Io: AsyncRead + AsyncWrite + Unpin + 'static, { fn drop(&mut self) { if let Some(i) = self.inner.take() { @@ -181,113 +209,43 @@ where } } -impl Future for WaitForConnection +struct OpenGuard where - Io: AsyncRead + AsyncWrite, + Io: AsyncRead + AsyncWrite + Unpin + 'static, { - type Item = IoConnection; - type Error = ConnectError; - - fn poll(&mut self) -> Poll { - match self.rx.poll() { - Ok(Async::Ready(item)) => match item { - Err(err) => Err(err), - Ok(conn) => { - let _ = self.inner.take(); - Ok(Async::Ready(conn)) - } - }, - Ok(Async::NotReady) => Ok(Async::NotReady), - Err(_) => { - let _ = self.inner.take(); - Err(ConnectError::Disconnected) - } - } - } -} - -#[doc(hidden)] -pub struct OpenConnection -where - Io: AsyncRead + AsyncWrite + 'static, -{ - fut: F, key: Key, - h2: Option>, inner: Option>>>, } -impl OpenConnection +impl OpenGuard where - F: Future, - Io: AsyncRead + AsyncWrite + 'static, + Io: AsyncRead + AsyncWrite + Unpin + 'static, { - fn new(key: Key, inner: Rc>>, fut: F) -> Self { - OpenConnection { + fn new(key: Key, inner: Rc>>) -> Self { + Self { key, - fut, inner: Some(inner), - h2: None, } } + + fn consume(mut self) -> Acquired { + Acquired(self.key.clone(), self.inner.take()) + } } -impl Drop for OpenConnection +impl Drop for OpenGuard where - Io: AsyncRead + AsyncWrite + 'static, + Io: AsyncRead + AsyncWrite + Unpin + 'static, { fn drop(&mut self) { - if let Some(inner) = self.inner.take() { - let mut inner = inner.as_ref().borrow_mut(); + if let Some(i) = self.inner.take() { + let mut inner = i.as_ref().borrow_mut(); inner.release(); inner.check_availibility(); } } } -impl Future for OpenConnection -where - F: Future, - Io: AsyncRead + AsyncWrite, -{ - type Item = IoConnection; - type Error = ConnectError; - - fn poll(&mut self) -> Poll { - if let Some(ref mut h2) = self.h2 { - return match h2.poll() { - Ok(Async::Ready((snd, connection))) => { - tokio_executor::current_thread::spawn(connection.map_err(|_| ())); - Ok(Async::Ready(IoConnection::new( - ConnectionType::H2(snd), - Instant::now(), - Some(Acquired(self.key.clone(), self.inner.take())), - ))) - } - Ok(Async::NotReady) => Ok(Async::NotReady), - Err(e) => Err(e.into()), - }; - } - - match self.fut.poll() { - Err(err) => Err(err), - Ok(Async::Ready((io, proto))) => { - if proto == Protocol::Http1 { - Ok(Async::Ready(IoConnection::new( - ConnectionType::H1(io), - Instant::now(), - Some(Acquired(self.key.clone(), self.inner.take())), - ))) - } else { - self.h2 = Some(handshake(io)); - self.poll() - } - } - Ok(Async::NotReady) => Ok(Async::NotReady), - } - } -} - enum Acquire { Acquired(ConnectionType, Instant), Available, @@ -314,7 +272,7 @@ pub(crate) struct Inner { )>, >, waiters_queue: IndexSet<(Key, usize)>, - task: Option, + waker: LocalWaker, } impl Inner { @@ -334,7 +292,7 @@ impl Inner { impl Inner where - Io: AsyncRead + AsyncWrite + 'static, + Io: AsyncRead + AsyncWrite + Unpin + 'static, { /// connection is not available, wait fn wait_for( @@ -343,7 +301,6 @@ where ) -> ( oneshot::Receiver, ConnectError>>, usize, - bool, ) { let (tx, rx) = oneshot::channel(); @@ -353,10 +310,10 @@ where entry.insert(Some((connect, tx))); assert!(self.waiters_queue.insert((key, token))); - (rx, token, self.task.is_some()) + (rx, token) } - fn acquire(&mut self, key: &Key) -> Acquire { + fn acquire(&mut self, key: &Key, cx: &mut Context) -> Acquire { // check limits if self.limit > 0 && self.acquired >= self.limit { return Acquire::NotAvailable; @@ -384,9 +341,9 @@ where let mut io = conn.io; let mut buf = [0; 2]; if let ConnectionType::H1(ref mut s) = io { - match s.read(&mut buf) { - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => (), - Ok(n) if n > 0 => { + match Pin::new(s).poll_read(cx, &mut buf) { + Poll::Pending => (), + Poll::Ready(Ok(n)) if n > 0 => { if let Some(timeout) = self.disconnect_timeout { if let ConnectionType::H1(io) = io { tokio_executor::current_thread::spawn( @@ -396,7 +353,7 @@ where } continue; } - Ok(_) | Err(_) => continue, + _ => continue, } } return Acquire::Acquired(io, conn.created); @@ -431,9 +388,7 @@ where fn check_availibility(&self) { if !self.waiters_queue.is_empty() && self.acquired < self.limit { - if let Some(t) = self.task.as_ref() { - t.notify() - } + self.waker.wake(); } } } @@ -457,17 +412,16 @@ where impl Future for CloseConnection where - T: AsyncWrite, + T: AsyncWrite + Unpin, { - type Item = (); - type Error = (); + type Output = (); - fn poll(&mut self) -> Poll<(), ()> { - match self.timeout.poll() { - Ok(Async::Ready(_)) | Err(_) => Ok(Async::Ready(())), - Ok(Async::NotReady) => match self.io.shutdown() { - Ok(Async::Ready(_)) | Err(_) => Ok(Async::Ready(())), - Ok(Async::NotReady) => Ok(Async::NotReady), + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<()> { + match Pin::new(&mut self.timeout).poll(cx) { + Poll::Ready(_) => Poll::Ready(()), + Poll::Pending => match Pin::new(&mut self.io).poll_shutdown(cx) { + Poll::Ready(_) => Poll::Ready(()), + Poll::Pending => Poll::Pending, }, } } @@ -483,16 +437,18 @@ where impl Future for ConnectorPoolSupport where - Io: AsyncRead + AsyncWrite + 'static, - T: Service, - T::Future: 'static, + Io: AsyncRead + AsyncWrite + Unpin + 'static, + T: Service + + Unpin, + T::Future: Unpin + 'static, { - type Item = (); - type Error = (); + type Output = (); - fn poll(&mut self) -> Poll { - let mut inner = self.inner.as_ref().borrow_mut(); - inner.task.as_ref().unwrap().register(); + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let this = self.get_mut(); + + let mut inner = this.inner.as_ref().borrow_mut(); + inner.waker.register(cx.waker()); // check waiters loop { @@ -507,14 +463,14 @@ where continue; } - match inner.acquire(&key) { + match inner.acquire(&key, cx) { Acquire::NotAvailable => break, Acquire::Acquired(io, created) => { let tx = inner.waiters.get_mut(token).unwrap().take().unwrap().1; if let Err(conn) = tx.send(Ok(IoConnection::new( io, created, - Some(Acquired(key.clone(), Some(self.inner.clone()))), + Some(Acquired(key.clone(), Some(this.inner.clone()))), ))) { let (io, created) = conn.unwrap().into_inner(); inner.release_conn(&key, io, created); @@ -526,33 +482,38 @@ where OpenWaitingConnection::spawn( key.clone(), tx, - self.inner.clone(), - self.connector.call(connect), + this.inner.clone(), + this.connector.call(connect), ); } } let _ = inner.waiters_queue.swap_remove_index(0); } - Ok(Async::NotReady) + Poll::Pending } } struct OpenWaitingConnection where - Io: AsyncRead + AsyncWrite + 'static, + Io: AsyncRead + AsyncWrite + Unpin + 'static, { fut: F, key: Key, - h2: Option>, + h2: Option< + LocalBoxFuture< + 'static, + Result<(SendRequest, Connection), h2::Error>, + >, + >, rx: Option, ConnectError>>>, inner: Option>>>, } impl OpenWaitingConnection where - F: Future + 'static, - Io: AsyncRead + AsyncWrite + 'static, + F: Future> + Unpin + 'static, + Io: AsyncRead + AsyncWrite + Unpin + 'static, { fn spawn( key: Key, @@ -572,7 +533,7 @@ where impl Drop for OpenWaitingConnection where - Io: AsyncRead + AsyncWrite + 'static, + Io: AsyncRead + AsyncWrite + Unpin + 'static, { fn drop(&mut self) { if let Some(inner) = self.inner.take() { @@ -585,59 +546,60 @@ where impl Future for OpenWaitingConnection where - F: Future, - Io: AsyncRead + AsyncWrite, + F: Future> + Unpin, + Io: AsyncRead + AsyncWrite + Unpin, { - type Item = (); - type Error = (); + type Output = (); - fn poll(&mut self) -> Poll { - if let Some(ref mut h2) = self.h2 { - return match h2.poll() { - Ok(Async::Ready((snd, connection))) => { - tokio_executor::current_thread::spawn(connection.map_err(|_| ())); - let rx = self.rx.take().unwrap(); + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let this = self.get_mut(); + + if let Some(ref mut h2) = this.h2 { + return match Pin::new(h2).poll(cx) { + Poll::Ready(Ok((snd, connection))) => { + tokio_executor::current_thread::spawn(connection.map(|_| ())); + let rx = this.rx.take().unwrap(); let _ = rx.send(Ok(IoConnection::new( ConnectionType::H2(snd), Instant::now(), - Some(Acquired(self.key.clone(), self.inner.take())), + Some(Acquired(this.key.clone(), this.inner.take())), ))); - Ok(Async::Ready(())) + Poll::Ready(()) } - Ok(Async::NotReady) => Ok(Async::NotReady), - Err(err) => { - let _ = self.inner.take(); - if let Some(rx) = self.rx.take() { + Poll::Pending => Poll::Pending, + Poll::Ready(Err(err)) => { + let _ = this.inner.take(); + if let Some(rx) = this.rx.take() { let _ = rx.send(Err(ConnectError::H2(err))); } - Err(()) + Poll::Ready(()) } }; } - match self.fut.poll() { - Err(err) => { - let _ = self.inner.take(); - if let Some(rx) = self.rx.take() { + match Pin::new(&mut this.fut).poll(cx) { + Poll::Ready(Err(err)) => { + let _ = this.inner.take(); + if let Some(rx) = this.rx.take() { let _ = rx.send(Err(err)); } - Err(()) + Poll::Ready(()) } - Ok(Async::Ready((io, proto))) => { + Poll::Ready(Ok((io, proto))) => { if proto == Protocol::Http1 { - let rx = self.rx.take().unwrap(); + let rx = this.rx.take().unwrap(); let _ = rx.send(Ok(IoConnection::new( ConnectionType::H1(io), Instant::now(), - Some(Acquired(self.key.clone(), self.inner.take())), + Some(Acquired(this.key.clone(), this.inner.take())), ))); - Ok(Async::Ready(())) + Poll::Ready(()) } else { - self.h2 = Some(handshake(io)); - self.poll() + this.h2 = Some(handshake(io).boxed_local()); + Pin::new(this).poll(cx) } } - Ok(Async::NotReady) => Ok(Async::NotReady), + Poll::Pending => Poll::Pending, } } } @@ -646,7 +608,7 @@ pub(crate) struct Acquired(Key, Option>>>); impl Acquired where - T: AsyncRead + AsyncWrite + 'static, + T: AsyncRead + AsyncWrite + Unpin + 'static, { pub(crate) fn close(&mut self, conn: IoConnection) { if let Some(inner) = self.1.take() { diff --git a/actix-http/src/h1/utils.rs b/actix-http/src/h1/utils.rs index bc6914d3b..a992089c0 100644 --- a/actix-http/src/h1/utils.rs +++ b/actix-http/src/h1/utils.rs @@ -55,7 +55,7 @@ where if item.is_none() { let _ = this.body.take(); } - framed.force_send(Message::Chunk(item))?; + framed.write(Message::Chunk(item))?; } Poll::Pending => body_ready = false, } @@ -78,7 +78,7 @@ where // send response if let Some(res) = this.res.take() { - framed.force_send(res)?; + framed.write(res)?; continue; } diff --git a/actix-http/src/lib.rs b/actix-http/src/lib.rs index cf528aeec..4d17347db 100644 --- a/actix-http/src/lib.rs +++ b/actix-http/src/lib.rs @@ -13,7 +13,7 @@ extern crate log; pub mod body; mod builder; -// pub mod client; +pub mod client; mod cloneable; mod config; pub mod encoding; @@ -32,8 +32,8 @@ pub mod cookie; pub mod error; pub mod h1; pub mod h2; -// pub mod test; -// pub mod ws; +pub mod test; +pub mod ws; pub use self::builder::HttpServiceBuilder; pub use self::config::{KeepAlive, ServiceConfig}; diff --git a/actix-http/src/test.rs b/actix-http/src/test.rs index 817bf480d..26f2c223f 100644 --- a/actix-http/src/test.rs +++ b/actix-http/src/test.rs @@ -1,6 +1,6 @@ //! Test Various helpers for Actix applications to use during testing. use std::fmt::Write as FmtWrite; -use std::io; +use std::io::{self, Read, Write}; use std::pin::Pin; use std::str::FromStr; use std::task::{Context, Poll}; @@ -245,16 +245,33 @@ impl io::Write for TestBuffer { } } -// impl AsyncRead for TestBuffer {} +impl AsyncRead for TestBuffer { + fn poll_read( + self: Pin<&mut Self>, + _: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + Poll::Ready(self.get_mut().read(buf)) + } +} -// impl AsyncWrite for TestBuffer { -// fn shutdown(&mut self) -> Poll<(), io::Error> { -// Ok(Async::Ready(())) -// } -// fn write_buf(&mut self, _: &mut B) -> Poll { -// Ok(Async::NotReady) -// } -// } +impl AsyncWrite for TestBuffer { + fn poll_write( + self: Pin<&mut Self>, + _: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Poll::Ready(self.get_mut().write(buf)) + } + + fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} impl IoStream for TestBuffer { fn set_nodelay(&mut self, _nodelay: bool) -> io::Result<()> { diff --git a/actix-http/src/ws/transport.rs b/actix-http/src/ws/transport.rs index da7782be5..c55e2eebd 100644 --- a/actix-http/src/ws/transport.rs +++ b/actix-http/src/ws/transport.rs @@ -1,24 +1,27 @@ +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_service::{IntoService, Service}; use actix_utils::framed::{FramedTransport, FramedTransportError}; -use futures::{Future, Poll}; use super::{Codec, Frame, Message}; pub struct Transport where S: Service + 'static, - T: AsyncRead + AsyncWrite, + T: AsyncRead + AsyncWrite + Unpin, { inner: FramedTransport, } impl Transport where - T: AsyncRead + AsyncWrite, - S: Service, + T: AsyncRead + AsyncWrite + Unpin, + S: Service + Unpin, S::Future: 'static, - S::Error: 'static, + S::Error: Unpin + 'static, { pub fn new>(io: T, service: F) -> Self { Transport { @@ -35,15 +38,14 @@ where impl Future for Transport where - T: AsyncRead + AsyncWrite, - S: Service, + T: AsyncRead + AsyncWrite + Unpin, + S: Service + Unpin, S::Future: 'static, - S::Error: 'static, + S::Error: Unpin + 'static, { - type Item = (); - type Error = FramedTransportError; + type Output = Result<(), FramedTransportError>; - fn poll(&mut self) -> Poll { - self.inner.poll() + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + Pin::new(&mut self.inner).poll(cx) } }