From ec5c7797325e63d59eaa54597916d72e95343975 Mon Sep 17 00:00:00 2001 From: Maksym Vorobiov Date: Mon, 3 Feb 2020 22:55:49 +0200 Subject: [PATCH] unlink MessageBody from Unpin --- actix-http/src/body.rs | 10 +- actix-http/src/h1/dispatcher.rs | 308 ++++++++++++++++++-------------- actix-http/src/h1/utils.rs | 2 +- 3 files changed, 181 insertions(+), 139 deletions(-) diff --git a/actix-http/src/body.rs b/actix-http/src/body.rs index 74e6e218d..26134723d 100644 --- a/actix-http/src/body.rs +++ b/actix-http/src/body.rs @@ -33,7 +33,7 @@ impl BodySize { } /// Type that provides this trait can be streamed to a peer. -pub trait MessageBody: Unpin { +pub trait MessageBody { fn size(&self) -> BodySize; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>>; @@ -53,14 +53,13 @@ impl MessageBody for () { } } -impl MessageBody for Box { +impl MessageBody for Box { fn size(&self) -> BodySize { self.as_ref().size() } fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>> { - let a: Pin<&mut T> = Pin::new(self.get_mut().as_mut()); - a.poll_next(cx) + unsafe { self.map_unchecked_mut(|boxed| boxed.as_mut()) }.poll_next(cx) } } @@ -70,8 +69,7 @@ impl MessageBody for Box { } fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>> { - let a: Pin<&mut dyn MessageBody> = Pin::new(self.get_mut().as_mut()); - a.poll_next(cx) + unsafe { Pin::new_unchecked(self.get_mut().as_mut()) }.poll_next(cx) } } diff --git a/actix-http/src/h1/dispatcher.rs b/actix-http/src/h1/dispatcher.rs index 0f897561d..7429c50f7 100644 --- a/actix-http/src/h1/dispatcher.rs +++ b/actix-http/src/h1/dispatcher.rs @@ -10,6 +10,7 @@ use actix_service::Service; use bitflags::bitflags; use bytes::{Buf, BytesMut}; use log::{error, trace}; +use pin_project::pin_project; use crate::body::{Body, BodySize, MessageBody, ResponseBody}; use crate::cloneable::CloneableService; @@ -41,6 +42,7 @@ bitflags! { } } +#[pin_project::pin_project] /// Dispatcher for HTTP/1.1 protocol pub struct Dispatcher where @@ -52,9 +54,11 @@ where U: Service), Response = ()>, U::Error: fmt::Display, { + #[pin] inner: DispatcherState, } +#[pin_project] enum DispatcherState where S: Service, @@ -65,11 +69,12 @@ where U: Service), Response = ()>, U::Error: fmt::Display, { - Normal(InnerDispatcher), - Upgrade(Pin>), + Normal(#[pin] InnerDispatcher), + Upgrade(#[pin] U::Future), None, } +#[pin_project] struct InnerDispatcher where S: Service, @@ -88,6 +93,7 @@ where peer_addr: Option, error: Option, + #[pin] state: State, payload: Option, messages: VecDeque, @@ -107,6 +113,7 @@ enum DispatcherMessage { Error(Response<()>), } +#[pin_project] enum State where S: Service, @@ -114,9 +121,9 @@ where B: MessageBody, { None, - ExpectCall(Pin>), - ServiceCall(Pin>), - SendPayload(ResponseBody), + ExpectCall(#[pin] X::Future), + ServiceCall(#[pin] S::Future), + SendPayload(#[pin] ResponseBody), } impl State @@ -142,6 +149,21 @@ where } } +impl DispatcherState +where + S: Service, + S::Error: Into, + B: MessageBody, + X: Service, + X::Error: Into, + U: Service), Response = ()>, + U::Error: fmt::Display, +{ + fn take(self: Pin<&mut Self>) -> Self { + std::mem::replace(unsafe { self.get_unchecked_mut() }, Self::None) + } +} + enum PollResponse { Upgrade(Request), DoNothing, @@ -278,10 +300,11 @@ where } // if checked is set to true, delay disconnect until all tasks have finished. - fn client_disconnected(&mut self) { - self.flags + fn client_disconnected(self: Pin<&mut Self>) { + let this = self.project(); + this.flags .insert(Flags::READ_DISCONNECT | Flags::WRITE_DISCONNECT); - if let Some(mut payload) = self.payload.take() { + if let Some(mut payload) = this.payload.take() { payload.set_error(PayloadError::Incomplete(None)); } } @@ -290,16 +313,18 @@ where /// /// true - got whouldblock /// false - didnt get whouldblock - fn poll_flush(&mut self, cx: &mut Context<'_>) -> Result { + #[pin_project::project] + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Result { if self.write_buf.is_empty() { return Ok(false); } let len = self.write_buf.len(); let mut written = 0; + #[project] + let InnerDispatcher { mut io, write_buf, .. } = self.project(); while written < len { - match Pin::new(&mut self.io) - .poll_write(cx, &self.write_buf[written..]) + match Pin::new(&mut io).poll_write(cx, &write_buf[written..]) { Poll::Ready(Ok(0)) => { return Err(DispatchError::Io(io::Error::new( @@ -312,113 +337,120 @@ where } Poll::Pending => { if written > 0 { - self.write_buf.advance(written); + write_buf.advance(written); } return Ok(true); } Poll::Ready(Err(err)) => return Err(DispatchError::Io(err)), } } - if written == self.write_buf.len() { - unsafe { self.write_buf.set_len(0) } + if written == write_buf.len() { + unsafe { write_buf.set_len(0) } } else { - self.write_buf.advance(written); + write_buf.advance(written); } Ok(false) } fn send_response( - &mut self, + self: Pin<&mut Self>, message: Response<()>, body: ResponseBody, ) -> Result, DispatchError> { - self.codec - .encode(Message::Item((message, body.size())), &mut self.write_buf) + let mut this = self.project(); + this.codec + .encode(Message::Item((message, body.size())), &mut this.write_buf) .map_err(|err| { - if let Some(mut payload) = self.payload.take() { + if let Some(mut payload) = this.payload.take() { payload.set_error(PayloadError::Incomplete(None)); } DispatchError::Io(err) })?; - self.flags.set(Flags::KEEPALIVE, self.codec.keepalive()); + this.flags.set(Flags::KEEPALIVE, this.codec.keepalive()); match body.size() { BodySize::None | BodySize::Empty => Ok(State::None), _ => Ok(State::SendPayload(body)), } } - fn send_continue(&mut self) { - self.write_buf + fn send_continue(self: Pin<&mut Self>) { + self.project().write_buf .extend_from_slice(b"HTTP/1.1 100 Continue\r\n\r\n"); } + #[pin_project::project] fn poll_response( - &mut self, + mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Result { loop { - let state = match self.state { - State::None => match self.messages.pop_front() { + let mut this = self.as_mut().project(); + #[project] + let state = match this.state.project() { + State::None => match this.messages.pop_front() { Some(DispatcherMessage::Item(req)) => { - Some(self.handle_request(req, cx)?) + Some(self.as_mut().handle_request(req, cx)?) } Some(DispatcherMessage::Error(res)) => { - Some(self.send_response(res, ResponseBody::Other(Body::Empty))?) + Some(self.as_mut().send_response(res, ResponseBody::Other(Body::Empty))?) } Some(DispatcherMessage::Upgrade(req)) => { return Ok(PollResponse::Upgrade(req)); } None => None, }, - State::ExpectCall(ref mut fut) => { - match fut.as_mut().poll(cx) { + State::ExpectCall(fut) => { + match fut.poll(cx) { Poll::Ready(Ok(req)) => { - self.send_continue(); - self.state = State::ServiceCall(Box::pin(self.service.call(req))); + self.as_mut().send_continue(); + this = self.as_mut().project(); + this.state.set(State::ServiceCall(this.service.call(req))); continue; } Poll::Ready(Err(e)) => { let res: Response = e.into().into(); let (res, body) = res.replace_body(()); - Some(self.send_response(res, body.into_body())?) + Some(self.as_mut().send_response(res, body.into_body())?) } Poll::Pending => None, } } - State::ServiceCall(ref mut fut) => { - match fut.as_mut().poll(cx) { + State::ServiceCall(fut) => { + match fut.poll(cx) { Poll::Ready(Ok(res)) => { let (res, body) = res.into().replace_body(()); - self.state = self.send_response(res, body)?; + let state = self.as_mut().send_response(res, body)?; + this = self.as_mut().project(); + this.state.set(state); continue; } Poll::Ready(Err(e)) => { let res: Response = e.into().into(); let (res, body) = res.replace_body(()); - Some(self.send_response(res, body.into_body())?) + Some(self.as_mut().send_response(res, body.into_body())?) } Poll::Pending => None, } } - State::SendPayload(ref mut stream) => { - let mut stream = Pin::new(stream); + State::SendPayload(mut stream) => { loop { - if self.write_buf.len() < HW_BUFFER_SIZE { + if this.write_buf.len() < HW_BUFFER_SIZE { match stream.as_mut().poll_next(cx) { Poll::Ready(Some(Ok(item))) => { - self.codec.encode( + this.codec.encode( Message::Chunk(Some(item)), - &mut self.write_buf, + &mut this.write_buf, )?; continue; } Poll::Ready(None) => { - self.codec.encode( + this.codec.encode( Message::Chunk(None), - &mut self.write_buf, + &mut this.write_buf, )?; - self.state = State::None; + this = self.as_mut().project(); + this.state.set(State::None); } Poll::Ready(Some(Err(_))) => { return Err(DispatchError::Unknown) @@ -434,9 +466,11 @@ where } }; + this = self.as_mut().project(); + // set new state if let Some(state) = state { - self.state = state; + this.state.set(state); if !self.state.is_empty() { continue; } @@ -444,7 +478,7 @@ where // if read-backpressure is enabled and we consumed some data. // we may read more data and retry if self.state.is_call() { - if self.poll_request(cx)? { + if self.as_mut().poll_request(cx)? { continue; } } else if !self.messages.is_empty() { @@ -458,16 +492,16 @@ where } fn handle_request( - &mut self, + mut self: Pin<&mut Self>, req: Request, cx: &mut Context<'_>, ) -> Result, DispatchError> { // Handle `EXPECT: 100-Continue` header let req = if req.head().expect() { - let mut task = Box::pin(self.expect.call(req)); - match task.as_mut().poll(cx) { + let mut task = self.as_mut().project().expect.call(req); + match unsafe { Pin::new_unchecked(&mut task) }.poll(cx) { Poll::Ready(Ok(req)) => { - self.send_continue(); + self.as_mut().send_continue(); req } Poll::Pending => return Ok(State::ExpectCall(task)), @@ -483,8 +517,8 @@ where }; // Call service - let mut task = Box::pin(self.service.call(req)); - match task.as_mut().poll(cx) { + let mut task = self.as_mut().project().service.call(req); + match unsafe { Pin::new_unchecked(&mut task) }.poll(cx) { Poll::Ready(Ok(res)) => { let (res, body) = res.into().replace_body(()); self.send_response(res, body) @@ -500,7 +534,7 @@ where /// Process one incoming requests pub(self) fn poll_request( - &mut self, + mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Result { // limit a mount of non processed requests @@ -509,24 +543,25 @@ where } let mut updated = false; + let mut this = self.as_mut().project(); loop { - match self.codec.decode(&mut self.read_buf) { + match this.codec.decode(&mut this.read_buf) { Ok(Some(msg)) => { updated = true; - self.flags.insert(Flags::STARTED); + this.flags.insert(Flags::STARTED); match msg { Message::Item(mut req) => { - let pl = self.codec.message_type(); - req.head_mut().peer_addr = self.peer_addr; + let pl = this.codec.message_type(); + req.head_mut().peer_addr = *this.peer_addr; // set on_connect data - if let Some(ref on_connect) = self.on_connect { + if let Some(ref on_connect) = this.on_connect { on_connect.set(&mut req.extensions_mut()); } - if pl == MessageType::Stream && self.upgrade.is_some() { - self.messages.push_back(DispatcherMessage::Upgrade(req)); + if pl == MessageType::Stream && this.upgrade.is_some() { + this.messages.push_back(DispatcherMessage::Upgrade(req)); break; } if pl == MessageType::Payload || pl == MessageType::Stream { @@ -534,41 +569,43 @@ where let (req1, _) = req.replace_payload(crate::Payload::H1(pl)); req = req1; - self.payload = Some(ps); + *this.payload = Some(ps); } // handle request early - if self.state.is_empty() { - self.state = self.handle_request(req, cx)?; + if this.state.is_empty() { + let state = self.as_mut().handle_request(req, cx)?; + this = self.as_mut().project(); + this.state.set(state); } else { - self.messages.push_back(DispatcherMessage::Item(req)); + this.messages.push_back(DispatcherMessage::Item(req)); } } Message::Chunk(Some(chunk)) => { - if let Some(ref mut payload) = self.payload { + if let Some(ref mut payload) = this.payload { payload.feed_data(chunk); } else { error!( "Internal server error: unexpected payload chunk" ); - self.flags.insert(Flags::READ_DISCONNECT); - self.messages.push_back(DispatcherMessage::Error( + this.flags.insert(Flags::READ_DISCONNECT); + this.messages.push_back(DispatcherMessage::Error( Response::InternalServerError().finish().drop_body(), )); - self.error = Some(DispatchError::InternalError); + *this.error = Some(DispatchError::InternalError); break; } } Message::Chunk(None) => { - if let Some(mut payload) = self.payload.take() { + if let Some(mut payload) = this.payload.take() { payload.feed_eof(); } else { error!("Internal server error: unexpected eof"); - self.flags.insert(Flags::READ_DISCONNECT); - self.messages.push_back(DispatcherMessage::Error( + this.flags.insert(Flags::READ_DISCONNECT); + this.messages.push_back(DispatcherMessage::Error( Response::InternalServerError().finish().drop_body(), )); - self.error = Some(DispatchError::InternalError); + *this.error = Some(DispatchError::InternalError); break; } } @@ -576,44 +613,46 @@ where } Ok(None) => break, Err(ParseError::Io(e)) => { - self.client_disconnected(); - self.error = Some(DispatchError::Io(e)); + self.as_mut().client_disconnected(); + this = self.as_mut().project(); + *this.error = Some(DispatchError::Io(e)); break; } Err(e) => { - if let Some(mut payload) = self.payload.take() { + if let Some(mut payload) = this.payload.take() { payload.set_error(PayloadError::EncodingCorrupted); } // Malformed requests should be responded with 400 - self.messages.push_back(DispatcherMessage::Error( + this.messages.push_back(DispatcherMessage::Error( Response::BadRequest().finish().drop_body(), )); - self.flags.insert(Flags::READ_DISCONNECT); - self.error = Some(e.into()); + this.flags.insert(Flags::READ_DISCONNECT); + *this.error = Some(e.into()); break; } } } - if updated && self.ka_timer.is_some() { - if let Some(expire) = self.codec.config().keep_alive_expire() { - self.ka_expire = expire; + if updated && this.ka_timer.is_some() { + if let Some(expire) = this.codec.config().keep_alive_expire() { + *this.ka_expire = expire; } } Ok(updated) } /// keep-alive timer - fn poll_keepalive(&mut self, cx: &mut Context<'_>) -> Result<(), DispatchError> { - if self.ka_timer.is_none() { + fn poll_keepalive(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Result<(), DispatchError> { + let mut this = self.as_mut().project(); + if this.ka_timer.is_none() { // shutdown timeout - if self.flags.contains(Flags::SHUTDOWN) { - if let Some(interval) = self.codec.config().client_disconnect_timer() { - self.ka_timer = Some(delay_until(interval)); + if this.flags.contains(Flags::SHUTDOWN) { + if let Some(interval) = this.codec.config().client_disconnect_timer() { + *this.ka_timer = Some(delay_until(interval)); } else { - self.flags.insert(Flags::READ_DISCONNECT); - if let Some(mut payload) = self.payload.take() { + this.flags.insert(Flags::READ_DISCONNECT); + if let Some(mut payload) = this.payload.take() { payload.set_error(PayloadError::Incomplete(None)); } return Ok(()); @@ -623,55 +662,56 @@ where } } - match Pin::new(&mut self.ka_timer.as_mut().unwrap()).poll(cx) { + match Pin::new(&mut this.ka_timer.as_mut().unwrap()).poll(cx) { Poll::Ready(()) => { // if we get timeout during shutdown, drop connection - if self.flags.contains(Flags::SHUTDOWN) { + if this.flags.contains(Flags::SHUTDOWN) { return Err(DispatchError::DisconnectTimeout); - } else if self.ka_timer.as_mut().unwrap().deadline() >= self.ka_expire { + } else if this.ka_timer.as_mut().unwrap().deadline() >= *this.ka_expire { // check for any outstanding tasks - if self.state.is_empty() && self.write_buf.is_empty() { - if self.flags.contains(Flags::STARTED) { + if this.state.is_empty() && this.write_buf.is_empty() { + if this.flags.contains(Flags::STARTED) { trace!("Keep-alive timeout, close connection"); - self.flags.insert(Flags::SHUTDOWN); + this.flags.insert(Flags::SHUTDOWN); // start shutdown timer if let Some(deadline) = - self.codec.config().client_disconnect_timer() + this.codec.config().client_disconnect_timer() { - if let Some(mut timer) = self.ka_timer.as_mut() { + if let Some(mut timer) = this.ka_timer.as_mut() { timer.reset(deadline); let _ = Pin::new(&mut timer).poll(cx); } } else { // no shutdown timeout, drop socket - self.flags.insert(Flags::WRITE_DISCONNECT); + this.flags.insert(Flags::WRITE_DISCONNECT); return Ok(()); } } else { // timeout on first request (slow request) return 408 - if !self.flags.contains(Flags::STARTED) { + if !this.flags.contains(Flags::STARTED) { trace!("Slow request timeout"); - let _ = self.send_response( + let _ = self.as_mut().send_response( Response::RequestTimeout().finish().drop_body(), ResponseBody::Other(Body::Empty), ); + this = self.as_mut().project(); } else { trace!("Keep-alive connection timeout"); } - self.flags.insert(Flags::STARTED | Flags::SHUTDOWN); - self.state = State::None; + this.flags.insert(Flags::STARTED | Flags::SHUTDOWN); + this.state.set(State::None); } } else if let Some(deadline) = - self.codec.config().keep_alive_expire() + this.codec.config().keep_alive_expire() { - if let Some(mut timer) = self.ka_timer.as_mut() { + if let Some(mut timer) = this.ka_timer.as_mut() { timer.reset(deadline); let _ = Pin::new(&mut timer).poll(cx); } } - } else if let Some(mut timer) = self.ka_timer.as_mut() { - timer.reset(self.ka_expire); + } else if let Some(mut timer) = this.ka_timer.as_mut() { + timer.reset(*this.ka_expire); let _ = Pin::new(&mut timer).poll(cx); } } @@ -696,22 +736,25 @@ where { type Output = Result<(), DispatchError>; + #[pin_project::project] #[inline] fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - match self.as_mut().inner { - DispatcherState::Normal(ref mut inner) => { - inner.poll_keepalive(cx)?; + let this = self.as_mut().project(); + #[project] + match this.inner.project() { + DispatcherState::Normal(mut inner) => { + inner.as_mut().poll_keepalive(cx)?; if inner.flags.contains(Flags::SHUTDOWN) { if inner.flags.contains(Flags::WRITE_DISCONNECT) { Poll::Ready(Ok(())) } else { // flush buffer - inner.poll_flush(cx)?; + inner.as_mut().poll_flush(cx)?; if !inner.write_buf.is_empty() { Poll::Pending } else { - match Pin::new(&mut inner.io).poll_shutdown(cx) { + match Pin::new(inner.project().io).poll_shutdown(cx) { Poll::Ready(res) => { Poll::Ready(res.map_err(DispatchError::from)) } @@ -723,33 +766,34 @@ where // read socket into a buf let should_disconnect = if !inner.flags.contains(Flags::READ_DISCONNECT) { - read_available(cx, &mut inner.io, &mut inner.read_buf)? + let mut inner_p = inner.as_mut().project(); + read_available(cx, &mut inner_p.io, &mut inner_p.read_buf)? } else { None }; - inner.poll_request(cx)?; + inner.as_mut().poll_request(cx)?; if let Some(true) = should_disconnect { - inner.flags.insert(Flags::READ_DISCONNECT); - if let Some(mut payload) = inner.payload.take() { + let inner_p = inner.as_mut().project(); + inner_p.flags.insert(Flags::READ_DISCONNECT); + if let Some(mut payload) = inner_p.payload.take() { payload.feed_eof(); } }; loop { + let inner_p = inner.as_mut().project(); let remaining = - inner.write_buf.capacity() - inner.write_buf.len(); + inner_p.write_buf.capacity() - inner_p.write_buf.len(); if remaining < LW_BUFFER_SIZE { - inner.write_buf.reserve(HW_BUFFER_SIZE - remaining); + inner_p.write_buf.reserve(HW_BUFFER_SIZE - remaining); } - let result = inner.poll_response(cx)?; + let result = inner.as_mut().poll_response(cx)?; let drain = result == PollResponse::DrainWriteBuf; // switch to upgrade handler if let PollResponse::Upgrade(req) = result { - if let DispatcherState::Normal(inner) = - std::mem::replace(&mut self.inner, DispatcherState::None) - { + if let DispatcherState::Normal(inner) = self.as_mut().project().inner.take() { let mut parts = FramedParts::with_read_buf( inner.io, inner.codec, @@ -757,9 +801,8 @@ where ); parts.write_buf = inner.write_buf; let framed = Framed::from_parts(parts); - self.inner = DispatcherState::Upgrade( - Box::pin(inner.upgrade.unwrap().call((req, framed))), - ); + let upgrade = inner.upgrade.unwrap().call((req, framed)); + self.as_mut().project().inner.set(DispatcherState::Upgrade(upgrade)); return self.poll(cx); } else { panic!() @@ -769,7 +812,7 @@ where // we didnt get WouldBlock from write operation, // so data get written to kernel completely (OSX) // and we have to write again otherwise response can get stuck - if inner.poll_flush(cx)? || !drain { + if inner.as_mut().poll_flush(cx)? || !drain { break; } } @@ -781,25 +824,26 @@ where let is_empty = inner.state.is_empty(); + let inner_p = inner.as_mut().project(); // read half is closed and we do not processing any responses - if inner.flags.contains(Flags::READ_DISCONNECT) && is_empty { - inner.flags.insert(Flags::SHUTDOWN); + if inner_p.flags.contains(Flags::READ_DISCONNECT) && is_empty { + inner_p.flags.insert(Flags::SHUTDOWN); } // keep-alive and stream errors - if is_empty && inner.write_buf.is_empty() { - if let Some(err) = inner.error.take() { + if is_empty && inner_p.write_buf.is_empty() { + if let Some(err) = inner_p.error.take() { Poll::Ready(Err(err)) } // disconnect if keep-alive is not enabled - else if inner.flags.contains(Flags::STARTED) - && !inner.flags.intersects(Flags::KEEPALIVE) + else if inner_p.flags.contains(Flags::STARTED) + && !inner_p.flags.intersects(Flags::KEEPALIVE) { - inner.flags.insert(Flags::SHUTDOWN); + inner_p.flags.insert(Flags::SHUTDOWN); self.poll(cx) } // disconnect if shutdown - else if inner.flags.contains(Flags::SHUTDOWN) { + else if inner_p.flags.contains(Flags::SHUTDOWN) { self.poll(cx) } else { Poll::Pending diff --git a/actix-http/src/h1/utils.rs b/actix-http/src/h1/utils.rs index be6a42793..89013129a 100644 --- a/actix-http/src/h1/utils.rs +++ b/actix-http/src/h1/utils.rs @@ -36,7 +36,7 @@ where impl Future for SendResponse where T: AsyncRead + AsyncWrite, - B: MessageBody, + B: MessageBody + Unpin, { type Output = Result, Error>;