1
0
Fork 0
mirror of https://github.com/actix/actix-web.git synced 2025-01-24 16:08:06 +00:00

unlink MessageBody from Unpin

This commit is contained in:
Maksym Vorobiov 2020-02-03 22:55:49 +02:00 committed by Yuki Okushi
parent 2e2ea7ab80
commit ec5c779732
3 changed files with 181 additions and 139 deletions

View file

@ -33,7 +33,7 @@ impl BodySize {
} }
/// Type that provides this trait can be streamed to a peer. /// Type that provides this trait can be streamed to a peer.
pub trait MessageBody: Unpin { pub trait MessageBody {
fn size(&self) -> BodySize; fn size(&self) -> BodySize;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes, Error>>>; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes, Error>>>;
@ -53,14 +53,13 @@ impl MessageBody for () {
} }
} }
impl<T: MessageBody> MessageBody for Box<T> { impl<T: MessageBody + Unpin> MessageBody for Box<T> {
fn size(&self) -> BodySize { fn size(&self) -> BodySize {
self.as_ref().size() self.as_ref().size()
} }
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes, Error>>> { fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes, Error>>> {
let a: Pin<&mut T> = Pin::new(self.get_mut().as_mut()); unsafe { self.map_unchecked_mut(|boxed| boxed.as_mut()) }.poll_next(cx)
a.poll_next(cx)
} }
} }
@ -70,8 +69,7 @@ impl MessageBody for Box<dyn MessageBody> {
} }
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes, Error>>> { fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes, Error>>> {
let a: Pin<&mut dyn MessageBody> = Pin::new(self.get_mut().as_mut()); unsafe { Pin::new_unchecked(self.get_mut().as_mut()) }.poll_next(cx)
a.poll_next(cx)
} }
} }

View file

@ -10,6 +10,7 @@ use actix_service::Service;
use bitflags::bitflags; use bitflags::bitflags;
use bytes::{Buf, BytesMut}; use bytes::{Buf, BytesMut};
use log::{error, trace}; use log::{error, trace};
use pin_project::pin_project;
use crate::body::{Body, BodySize, MessageBody, ResponseBody}; use crate::body::{Body, BodySize, MessageBody, ResponseBody};
use crate::cloneable::CloneableService; use crate::cloneable::CloneableService;
@ -41,6 +42,7 @@ bitflags! {
} }
} }
#[pin_project::pin_project]
/// Dispatcher for HTTP/1.1 protocol /// Dispatcher for HTTP/1.1 protocol
pub struct Dispatcher<T, S, B, X, U> pub struct Dispatcher<T, S, B, X, U>
where where
@ -52,9 +54,11 @@ where
U: Service<Request = (Request, Framed<T, Codec>), Response = ()>, U: Service<Request = (Request, Framed<T, Codec>), Response = ()>,
U::Error: fmt::Display, U::Error: fmt::Display,
{ {
#[pin]
inner: DispatcherState<T, S, B, X, U>, inner: DispatcherState<T, S, B, X, U>,
} }
#[pin_project]
enum DispatcherState<T, S, B, X, U> enum DispatcherState<T, S, B, X, U>
where where
S: Service<Request = Request>, S: Service<Request = Request>,
@ -65,11 +69,12 @@ where
U: Service<Request = (Request, Framed<T, Codec>), Response = ()>, U: Service<Request = (Request, Framed<T, Codec>), Response = ()>,
U::Error: fmt::Display, U::Error: fmt::Display,
{ {
Normal(InnerDispatcher<T, S, B, X, U>), Normal(#[pin] InnerDispatcher<T, S, B, X, U>),
Upgrade(Pin<Box<U::Future>>), Upgrade(#[pin] U::Future),
None, None,
} }
#[pin_project]
struct InnerDispatcher<T, S, B, X, U> struct InnerDispatcher<T, S, B, X, U>
where where
S: Service<Request = Request>, S: Service<Request = Request>,
@ -88,6 +93,7 @@ where
peer_addr: Option<net::SocketAddr>, peer_addr: Option<net::SocketAddr>,
error: Option<DispatchError>, error: Option<DispatchError>,
#[pin]
state: State<S, B, X>, state: State<S, B, X>,
payload: Option<PayloadSender>, payload: Option<PayloadSender>,
messages: VecDeque<DispatcherMessage>, messages: VecDeque<DispatcherMessage>,
@ -107,6 +113,7 @@ enum DispatcherMessage {
Error(Response<()>), Error(Response<()>),
} }
#[pin_project]
enum State<S, B, X> enum State<S, B, X>
where where
S: Service<Request = Request>, S: Service<Request = Request>,
@ -114,9 +121,9 @@ where
B: MessageBody, B: MessageBody,
{ {
None, None,
ExpectCall(Pin<Box<X::Future>>), ExpectCall(#[pin] X::Future),
ServiceCall(Pin<Box<S::Future>>), ServiceCall(#[pin] S::Future),
SendPayload(ResponseBody<B>), SendPayload(#[pin] ResponseBody<B>),
} }
impl<S, B, X> State<S, B, X> impl<S, B, X> State<S, B, X>
@ -142,6 +149,21 @@ where
} }
} }
impl<T, S, B, X, U> DispatcherState<T, S, B, X, U>
where
S: Service<Request = Request>,
S::Error: Into<Error>,
B: MessageBody,
X: Service<Request = Request, Response = Request>,
X::Error: Into<Error>,
U: Service<Request = (Request, Framed<T, Codec>), Response = ()>,
U::Error: fmt::Display,
{
fn take(self: Pin<&mut Self>) -> Self {
std::mem::replace(unsafe { self.get_unchecked_mut() }, Self::None)
}
}
enum PollResponse { enum PollResponse {
Upgrade(Request), Upgrade(Request),
DoNothing, DoNothing,
@ -278,10 +300,11 @@ where
} }
// if checked is set to true, delay disconnect until all tasks have finished. // if checked is set to true, delay disconnect until all tasks have finished.
fn client_disconnected(&mut self) { fn client_disconnected(self: Pin<&mut Self>) {
self.flags let this = self.project();
this.flags
.insert(Flags::READ_DISCONNECT | Flags::WRITE_DISCONNECT); .insert(Flags::READ_DISCONNECT | Flags::WRITE_DISCONNECT);
if let Some(mut payload) = self.payload.take() { if let Some(mut payload) = this.payload.take() {
payload.set_error(PayloadError::Incomplete(None)); payload.set_error(PayloadError::Incomplete(None));
} }
} }
@ -290,16 +313,18 @@ where
/// ///
/// true - got whouldblock /// true - got whouldblock
/// false - didnt get whouldblock /// false - didnt get whouldblock
fn poll_flush(&mut self, cx: &mut Context<'_>) -> Result<bool, DispatchError> { #[pin_project::project]
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Result<bool, DispatchError> {
if self.write_buf.is_empty() { if self.write_buf.is_empty() {
return Ok(false); return Ok(false);
} }
let len = self.write_buf.len(); let len = self.write_buf.len();
let mut written = 0; let mut written = 0;
#[project]
let InnerDispatcher { mut io, write_buf, .. } = self.project();
while written < len { while written < len {
match Pin::new(&mut self.io) match Pin::new(&mut io).poll_write(cx, &write_buf[written..])
.poll_write(cx, &self.write_buf[written..])
{ {
Poll::Ready(Ok(0)) => { Poll::Ready(Ok(0)) => {
return Err(DispatchError::Io(io::Error::new( return Err(DispatchError::Io(io::Error::new(
@ -312,113 +337,120 @@ where
} }
Poll::Pending => { Poll::Pending => {
if written > 0 { if written > 0 {
self.write_buf.advance(written); write_buf.advance(written);
} }
return Ok(true); return Ok(true);
} }
Poll::Ready(Err(err)) => return Err(DispatchError::Io(err)), Poll::Ready(Err(err)) => return Err(DispatchError::Io(err)),
} }
} }
if written == self.write_buf.len() { if written == write_buf.len() {
unsafe { self.write_buf.set_len(0) } unsafe { write_buf.set_len(0) }
} else { } else {
self.write_buf.advance(written); write_buf.advance(written);
} }
Ok(false) Ok(false)
} }
fn send_response( fn send_response(
&mut self, self: Pin<&mut Self>,
message: Response<()>, message: Response<()>,
body: ResponseBody<B>, body: ResponseBody<B>,
) -> Result<State<S, B, X>, DispatchError> { ) -> Result<State<S, B, X>, DispatchError> {
self.codec let mut this = self.project();
.encode(Message::Item((message, body.size())), &mut self.write_buf) this.codec
.encode(Message::Item((message, body.size())), &mut this.write_buf)
.map_err(|err| { .map_err(|err| {
if let Some(mut payload) = self.payload.take() { if let Some(mut payload) = this.payload.take() {
payload.set_error(PayloadError::Incomplete(None)); payload.set_error(PayloadError::Incomplete(None));
} }
DispatchError::Io(err) DispatchError::Io(err)
})?; })?;
self.flags.set(Flags::KEEPALIVE, self.codec.keepalive()); this.flags.set(Flags::KEEPALIVE, this.codec.keepalive());
match body.size() { match body.size() {
BodySize::None | BodySize::Empty => Ok(State::None), BodySize::None | BodySize::Empty => Ok(State::None),
_ => Ok(State::SendPayload(body)), _ => Ok(State::SendPayload(body)),
} }
} }
fn send_continue(&mut self) { fn send_continue(self: Pin<&mut Self>) {
self.write_buf self.project().write_buf
.extend_from_slice(b"HTTP/1.1 100 Continue\r\n\r\n"); .extend_from_slice(b"HTTP/1.1 100 Continue\r\n\r\n");
} }
#[pin_project::project]
fn poll_response( fn poll_response(
&mut self, mut self: Pin<&mut Self>,
cx: &mut Context<'_>, cx: &mut Context<'_>,
) -> Result<PollResponse, DispatchError> { ) -> Result<PollResponse, DispatchError> {
loop { loop {
let state = match self.state { let mut this = self.as_mut().project();
State::None => match self.messages.pop_front() { #[project]
let state = match this.state.project() {
State::None => match this.messages.pop_front() {
Some(DispatcherMessage::Item(req)) => { Some(DispatcherMessage::Item(req)) => {
Some(self.handle_request(req, cx)?) Some(self.as_mut().handle_request(req, cx)?)
} }
Some(DispatcherMessage::Error(res)) => { Some(DispatcherMessage::Error(res)) => {
Some(self.send_response(res, ResponseBody::Other(Body::Empty))?) Some(self.as_mut().send_response(res, ResponseBody::Other(Body::Empty))?)
} }
Some(DispatcherMessage::Upgrade(req)) => { Some(DispatcherMessage::Upgrade(req)) => {
return Ok(PollResponse::Upgrade(req)); return Ok(PollResponse::Upgrade(req));
} }
None => None, None => None,
}, },
State::ExpectCall(ref mut fut) => { State::ExpectCall(fut) => {
match fut.as_mut().poll(cx) { match fut.poll(cx) {
Poll::Ready(Ok(req)) => { Poll::Ready(Ok(req)) => {
self.send_continue(); self.as_mut().send_continue();
self.state = State::ServiceCall(Box::pin(self.service.call(req))); this = self.as_mut().project();
this.state.set(State::ServiceCall(this.service.call(req)));
continue; continue;
} }
Poll::Ready(Err(e)) => { Poll::Ready(Err(e)) => {
let res: Response = e.into().into(); let res: Response = e.into().into();
let (res, body) = res.replace_body(()); let (res, body) = res.replace_body(());
Some(self.send_response(res, body.into_body())?) Some(self.as_mut().send_response(res, body.into_body())?)
} }
Poll::Pending => None, Poll::Pending => None,
} }
} }
State::ServiceCall(ref mut fut) => { State::ServiceCall(fut) => {
match fut.as_mut().poll(cx) { match fut.poll(cx) {
Poll::Ready(Ok(res)) => { Poll::Ready(Ok(res)) => {
let (res, body) = res.into().replace_body(()); let (res, body) = res.into().replace_body(());
self.state = self.send_response(res, body)?; let state = self.as_mut().send_response(res, body)?;
this = self.as_mut().project();
this.state.set(state);
continue; continue;
} }
Poll::Ready(Err(e)) => { Poll::Ready(Err(e)) => {
let res: Response = e.into().into(); let res: Response = e.into().into();
let (res, body) = res.replace_body(()); let (res, body) = res.replace_body(());
Some(self.send_response(res, body.into_body())?) Some(self.as_mut().send_response(res, body.into_body())?)
} }
Poll::Pending => None, Poll::Pending => None,
} }
} }
State::SendPayload(ref mut stream) => { State::SendPayload(mut stream) => {
let mut stream = Pin::new(stream);
loop { loop {
if self.write_buf.len() < HW_BUFFER_SIZE { if this.write_buf.len() < HW_BUFFER_SIZE {
match stream.as_mut().poll_next(cx) { match stream.as_mut().poll_next(cx) {
Poll::Ready(Some(Ok(item))) => { Poll::Ready(Some(Ok(item))) => {
self.codec.encode( this.codec.encode(
Message::Chunk(Some(item)), Message::Chunk(Some(item)),
&mut self.write_buf, &mut this.write_buf,
)?; )?;
continue; continue;
} }
Poll::Ready(None) => { Poll::Ready(None) => {
self.codec.encode( this.codec.encode(
Message::Chunk(None), Message::Chunk(None),
&mut self.write_buf, &mut this.write_buf,
)?; )?;
self.state = State::None; this = self.as_mut().project();
this.state.set(State::None);
} }
Poll::Ready(Some(Err(_))) => { Poll::Ready(Some(Err(_))) => {
return Err(DispatchError::Unknown) return Err(DispatchError::Unknown)
@ -434,9 +466,11 @@ where
} }
}; };
this = self.as_mut().project();
// set new state // set new state
if let Some(state) = state { if let Some(state) = state {
self.state = state; this.state.set(state);
if !self.state.is_empty() { if !self.state.is_empty() {
continue; continue;
} }
@ -444,7 +478,7 @@ where
// if read-backpressure is enabled and we consumed some data. // if read-backpressure is enabled and we consumed some data.
// we may read more data and retry // we may read more data and retry
if self.state.is_call() { if self.state.is_call() {
if self.poll_request(cx)? { if self.as_mut().poll_request(cx)? {
continue; continue;
} }
} else if !self.messages.is_empty() { } else if !self.messages.is_empty() {
@ -458,16 +492,16 @@ where
} }
fn handle_request( fn handle_request(
&mut self, mut self: Pin<&mut Self>,
req: Request, req: Request,
cx: &mut Context<'_>, cx: &mut Context<'_>,
) -> Result<State<S, B, X>, DispatchError> { ) -> Result<State<S, B, X>, DispatchError> {
// Handle `EXPECT: 100-Continue` header // Handle `EXPECT: 100-Continue` header
let req = if req.head().expect() { let req = if req.head().expect() {
let mut task = Box::pin(self.expect.call(req)); let mut task = self.as_mut().project().expect.call(req);
match task.as_mut().poll(cx) { match unsafe { Pin::new_unchecked(&mut task) }.poll(cx) {
Poll::Ready(Ok(req)) => { Poll::Ready(Ok(req)) => {
self.send_continue(); self.as_mut().send_continue();
req req
} }
Poll::Pending => return Ok(State::ExpectCall(task)), Poll::Pending => return Ok(State::ExpectCall(task)),
@ -483,8 +517,8 @@ where
}; };
// Call service // Call service
let mut task = Box::pin(self.service.call(req)); let mut task = self.as_mut().project().service.call(req);
match task.as_mut().poll(cx) { match unsafe { Pin::new_unchecked(&mut task) }.poll(cx) {
Poll::Ready(Ok(res)) => { Poll::Ready(Ok(res)) => {
let (res, body) = res.into().replace_body(()); let (res, body) = res.into().replace_body(());
self.send_response(res, body) self.send_response(res, body)
@ -500,7 +534,7 @@ where
/// Process one incoming requests /// Process one incoming requests
pub(self) fn poll_request( pub(self) fn poll_request(
&mut self, mut self: Pin<&mut Self>,
cx: &mut Context<'_>, cx: &mut Context<'_>,
) -> Result<bool, DispatchError> { ) -> Result<bool, DispatchError> {
// limit a mount of non processed requests // limit a mount of non processed requests
@ -509,24 +543,25 @@ where
} }
let mut updated = false; let mut updated = false;
let mut this = self.as_mut().project();
loop { loop {
match self.codec.decode(&mut self.read_buf) { match this.codec.decode(&mut this.read_buf) {
Ok(Some(msg)) => { Ok(Some(msg)) => {
updated = true; updated = true;
self.flags.insert(Flags::STARTED); this.flags.insert(Flags::STARTED);
match msg { match msg {
Message::Item(mut req) => { Message::Item(mut req) => {
let pl = self.codec.message_type(); let pl = this.codec.message_type();
req.head_mut().peer_addr = self.peer_addr; req.head_mut().peer_addr = *this.peer_addr;
// set on_connect data // set on_connect data
if let Some(ref on_connect) = self.on_connect { if let Some(ref on_connect) = this.on_connect {
on_connect.set(&mut req.extensions_mut()); on_connect.set(&mut req.extensions_mut());
} }
if pl == MessageType::Stream && self.upgrade.is_some() { if pl == MessageType::Stream && this.upgrade.is_some() {
self.messages.push_back(DispatcherMessage::Upgrade(req)); this.messages.push_back(DispatcherMessage::Upgrade(req));
break; break;
} }
if pl == MessageType::Payload || pl == MessageType::Stream { if pl == MessageType::Payload || pl == MessageType::Stream {
@ -534,41 +569,43 @@ where
let (req1, _) = let (req1, _) =
req.replace_payload(crate::Payload::H1(pl)); req.replace_payload(crate::Payload::H1(pl));
req = req1; req = req1;
self.payload = Some(ps); *this.payload = Some(ps);
} }
// handle request early // handle request early
if self.state.is_empty() { if this.state.is_empty() {
self.state = self.handle_request(req, cx)?; let state = self.as_mut().handle_request(req, cx)?;
this = self.as_mut().project();
this.state.set(state);
} else { } else {
self.messages.push_back(DispatcherMessage::Item(req)); this.messages.push_back(DispatcherMessage::Item(req));
} }
} }
Message::Chunk(Some(chunk)) => { Message::Chunk(Some(chunk)) => {
if let Some(ref mut payload) = self.payload { if let Some(ref mut payload) = this.payload {
payload.feed_data(chunk); payload.feed_data(chunk);
} else { } else {
error!( error!(
"Internal server error: unexpected payload chunk" "Internal server error: unexpected payload chunk"
); );
self.flags.insert(Flags::READ_DISCONNECT); this.flags.insert(Flags::READ_DISCONNECT);
self.messages.push_back(DispatcherMessage::Error( this.messages.push_back(DispatcherMessage::Error(
Response::InternalServerError().finish().drop_body(), Response::InternalServerError().finish().drop_body(),
)); ));
self.error = Some(DispatchError::InternalError); *this.error = Some(DispatchError::InternalError);
break; break;
} }
} }
Message::Chunk(None) => { Message::Chunk(None) => {
if let Some(mut payload) = self.payload.take() { if let Some(mut payload) = this.payload.take() {
payload.feed_eof(); payload.feed_eof();
} else { } else {
error!("Internal server error: unexpected eof"); error!("Internal server error: unexpected eof");
self.flags.insert(Flags::READ_DISCONNECT); this.flags.insert(Flags::READ_DISCONNECT);
self.messages.push_back(DispatcherMessage::Error( this.messages.push_back(DispatcherMessage::Error(
Response::InternalServerError().finish().drop_body(), Response::InternalServerError().finish().drop_body(),
)); ));
self.error = Some(DispatchError::InternalError); *this.error = Some(DispatchError::InternalError);
break; break;
} }
} }
@ -576,44 +613,46 @@ where
} }
Ok(None) => break, Ok(None) => break,
Err(ParseError::Io(e)) => { Err(ParseError::Io(e)) => {
self.client_disconnected(); self.as_mut().client_disconnected();
self.error = Some(DispatchError::Io(e)); this = self.as_mut().project();
*this.error = Some(DispatchError::Io(e));
break; break;
} }
Err(e) => { Err(e) => {
if let Some(mut payload) = self.payload.take() { if let Some(mut payload) = this.payload.take() {
payload.set_error(PayloadError::EncodingCorrupted); payload.set_error(PayloadError::EncodingCorrupted);
} }
// Malformed requests should be responded with 400 // Malformed requests should be responded with 400
self.messages.push_back(DispatcherMessage::Error( this.messages.push_back(DispatcherMessage::Error(
Response::BadRequest().finish().drop_body(), Response::BadRequest().finish().drop_body(),
)); ));
self.flags.insert(Flags::READ_DISCONNECT); this.flags.insert(Flags::READ_DISCONNECT);
self.error = Some(e.into()); *this.error = Some(e.into());
break; break;
} }
} }
} }
if updated && self.ka_timer.is_some() { if updated && this.ka_timer.is_some() {
if let Some(expire) = self.codec.config().keep_alive_expire() { if let Some(expire) = this.codec.config().keep_alive_expire() {
self.ka_expire = expire; *this.ka_expire = expire;
} }
} }
Ok(updated) Ok(updated)
} }
/// keep-alive timer /// keep-alive timer
fn poll_keepalive(&mut self, cx: &mut Context<'_>) -> Result<(), DispatchError> { fn poll_keepalive(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Result<(), DispatchError> {
if self.ka_timer.is_none() { let mut this = self.as_mut().project();
if this.ka_timer.is_none() {
// shutdown timeout // shutdown timeout
if self.flags.contains(Flags::SHUTDOWN) { if this.flags.contains(Flags::SHUTDOWN) {
if let Some(interval) = self.codec.config().client_disconnect_timer() { if let Some(interval) = this.codec.config().client_disconnect_timer() {
self.ka_timer = Some(delay_until(interval)); *this.ka_timer = Some(delay_until(interval));
} else { } else {
self.flags.insert(Flags::READ_DISCONNECT); this.flags.insert(Flags::READ_DISCONNECT);
if let Some(mut payload) = self.payload.take() { if let Some(mut payload) = this.payload.take() {
payload.set_error(PayloadError::Incomplete(None)); payload.set_error(PayloadError::Incomplete(None));
} }
return Ok(()); return Ok(());
@ -623,55 +662,56 @@ where
} }
} }
match Pin::new(&mut self.ka_timer.as_mut().unwrap()).poll(cx) { match Pin::new(&mut this.ka_timer.as_mut().unwrap()).poll(cx) {
Poll::Ready(()) => { Poll::Ready(()) => {
// if we get timeout during shutdown, drop connection // if we get timeout during shutdown, drop connection
if self.flags.contains(Flags::SHUTDOWN) { if this.flags.contains(Flags::SHUTDOWN) {
return Err(DispatchError::DisconnectTimeout); return Err(DispatchError::DisconnectTimeout);
} else if self.ka_timer.as_mut().unwrap().deadline() >= self.ka_expire { } else if this.ka_timer.as_mut().unwrap().deadline() >= *this.ka_expire {
// check for any outstanding tasks // check for any outstanding tasks
if self.state.is_empty() && self.write_buf.is_empty() { if this.state.is_empty() && this.write_buf.is_empty() {
if self.flags.contains(Flags::STARTED) { if this.flags.contains(Flags::STARTED) {
trace!("Keep-alive timeout, close connection"); trace!("Keep-alive timeout, close connection");
self.flags.insert(Flags::SHUTDOWN); this.flags.insert(Flags::SHUTDOWN);
// start shutdown timer // start shutdown timer
if let Some(deadline) = if let Some(deadline) =
self.codec.config().client_disconnect_timer() this.codec.config().client_disconnect_timer()
{ {
if let Some(mut timer) = self.ka_timer.as_mut() { if let Some(mut timer) = this.ka_timer.as_mut() {
timer.reset(deadline); timer.reset(deadline);
let _ = Pin::new(&mut timer).poll(cx); let _ = Pin::new(&mut timer).poll(cx);
} }
} else { } else {
// no shutdown timeout, drop socket // no shutdown timeout, drop socket
self.flags.insert(Flags::WRITE_DISCONNECT); this.flags.insert(Flags::WRITE_DISCONNECT);
return Ok(()); return Ok(());
} }
} else { } else {
// timeout on first request (slow request) return 408 // timeout on first request (slow request) return 408
if !self.flags.contains(Flags::STARTED) { if !this.flags.contains(Flags::STARTED) {
trace!("Slow request timeout"); trace!("Slow request timeout");
let _ = self.send_response( let _ = self.as_mut().send_response(
Response::RequestTimeout().finish().drop_body(), Response::RequestTimeout().finish().drop_body(),
ResponseBody::Other(Body::Empty), ResponseBody::Other(Body::Empty),
); );
this = self.as_mut().project();
} else { } else {
trace!("Keep-alive connection timeout"); trace!("Keep-alive connection timeout");
} }
self.flags.insert(Flags::STARTED | Flags::SHUTDOWN); this.flags.insert(Flags::STARTED | Flags::SHUTDOWN);
self.state = State::None; this.state.set(State::None);
} }
} else if let Some(deadline) = } else if let Some(deadline) =
self.codec.config().keep_alive_expire() this.codec.config().keep_alive_expire()
{ {
if let Some(mut timer) = self.ka_timer.as_mut() { if let Some(mut timer) = this.ka_timer.as_mut() {
timer.reset(deadline); timer.reset(deadline);
let _ = Pin::new(&mut timer).poll(cx); let _ = Pin::new(&mut timer).poll(cx);
} }
} }
} else if let Some(mut timer) = self.ka_timer.as_mut() { } else if let Some(mut timer) = this.ka_timer.as_mut() {
timer.reset(self.ka_expire); timer.reset(*this.ka_expire);
let _ = Pin::new(&mut timer).poll(cx); let _ = Pin::new(&mut timer).poll(cx);
} }
} }
@ -696,22 +736,25 @@ where
{ {
type Output = Result<(), DispatchError>; type Output = Result<(), DispatchError>;
#[pin_project::project]
#[inline] #[inline]
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.as_mut().inner { let this = self.as_mut().project();
DispatcherState::Normal(ref mut inner) => { #[project]
inner.poll_keepalive(cx)?; match this.inner.project() {
DispatcherState::Normal(mut inner) => {
inner.as_mut().poll_keepalive(cx)?;
if inner.flags.contains(Flags::SHUTDOWN) { if inner.flags.contains(Flags::SHUTDOWN) {
if inner.flags.contains(Flags::WRITE_DISCONNECT) { if inner.flags.contains(Flags::WRITE_DISCONNECT) {
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
} else { } else {
// flush buffer // flush buffer
inner.poll_flush(cx)?; inner.as_mut().poll_flush(cx)?;
if !inner.write_buf.is_empty() { if !inner.write_buf.is_empty() {
Poll::Pending Poll::Pending
} else { } else {
match Pin::new(&mut inner.io).poll_shutdown(cx) { match Pin::new(inner.project().io).poll_shutdown(cx) {
Poll::Ready(res) => { Poll::Ready(res) => {
Poll::Ready(res.map_err(DispatchError::from)) Poll::Ready(res.map_err(DispatchError::from))
} }
@ -723,33 +766,34 @@ where
// read socket into a buf // read socket into a buf
let should_disconnect = let should_disconnect =
if !inner.flags.contains(Flags::READ_DISCONNECT) { if !inner.flags.contains(Flags::READ_DISCONNECT) {
read_available(cx, &mut inner.io, &mut inner.read_buf)? let mut inner_p = inner.as_mut().project();
read_available(cx, &mut inner_p.io, &mut inner_p.read_buf)?
} else { } else {
None None
}; };
inner.poll_request(cx)?; inner.as_mut().poll_request(cx)?;
if let Some(true) = should_disconnect { if let Some(true) = should_disconnect {
inner.flags.insert(Flags::READ_DISCONNECT); let inner_p = inner.as_mut().project();
if let Some(mut payload) = inner.payload.take() { inner_p.flags.insert(Flags::READ_DISCONNECT);
if let Some(mut payload) = inner_p.payload.take() {
payload.feed_eof(); payload.feed_eof();
} }
}; };
loop { loop {
let inner_p = inner.as_mut().project();
let remaining = let remaining =
inner.write_buf.capacity() - inner.write_buf.len(); inner_p.write_buf.capacity() - inner_p.write_buf.len();
if remaining < LW_BUFFER_SIZE { if remaining < LW_BUFFER_SIZE {
inner.write_buf.reserve(HW_BUFFER_SIZE - remaining); inner_p.write_buf.reserve(HW_BUFFER_SIZE - remaining);
} }
let result = inner.poll_response(cx)?; let result = inner.as_mut().poll_response(cx)?;
let drain = result == PollResponse::DrainWriteBuf; let drain = result == PollResponse::DrainWriteBuf;
// switch to upgrade handler // switch to upgrade handler
if let PollResponse::Upgrade(req) = result { if let PollResponse::Upgrade(req) = result {
if let DispatcherState::Normal(inner) = if let DispatcherState::Normal(inner) = self.as_mut().project().inner.take() {
std::mem::replace(&mut self.inner, DispatcherState::None)
{
let mut parts = FramedParts::with_read_buf( let mut parts = FramedParts::with_read_buf(
inner.io, inner.io,
inner.codec, inner.codec,
@ -757,9 +801,8 @@ where
); );
parts.write_buf = inner.write_buf; parts.write_buf = inner.write_buf;
let framed = Framed::from_parts(parts); let framed = Framed::from_parts(parts);
self.inner = DispatcherState::Upgrade( let upgrade = inner.upgrade.unwrap().call((req, framed));
Box::pin(inner.upgrade.unwrap().call((req, framed))), self.as_mut().project().inner.set(DispatcherState::Upgrade(upgrade));
);
return self.poll(cx); return self.poll(cx);
} else { } else {
panic!() panic!()
@ -769,7 +812,7 @@ where
// we didnt get WouldBlock from write operation, // we didnt get WouldBlock from write operation,
// so data get written to kernel completely (OSX) // so data get written to kernel completely (OSX)
// and we have to write again otherwise response can get stuck // and we have to write again otherwise response can get stuck
if inner.poll_flush(cx)? || !drain { if inner.as_mut().poll_flush(cx)? || !drain {
break; break;
} }
} }
@ -781,25 +824,26 @@ where
let is_empty = inner.state.is_empty(); let is_empty = inner.state.is_empty();
let inner_p = inner.as_mut().project();
// read half is closed and we do not processing any responses // read half is closed and we do not processing any responses
if inner.flags.contains(Flags::READ_DISCONNECT) && is_empty { if inner_p.flags.contains(Flags::READ_DISCONNECT) && is_empty {
inner.flags.insert(Flags::SHUTDOWN); inner_p.flags.insert(Flags::SHUTDOWN);
} }
// keep-alive and stream errors // keep-alive and stream errors
if is_empty && inner.write_buf.is_empty() { if is_empty && inner_p.write_buf.is_empty() {
if let Some(err) = inner.error.take() { if let Some(err) = inner_p.error.take() {
Poll::Ready(Err(err)) Poll::Ready(Err(err))
} }
// disconnect if keep-alive is not enabled // disconnect if keep-alive is not enabled
else if inner.flags.contains(Flags::STARTED) else if inner_p.flags.contains(Flags::STARTED)
&& !inner.flags.intersects(Flags::KEEPALIVE) && !inner_p.flags.intersects(Flags::KEEPALIVE)
{ {
inner.flags.insert(Flags::SHUTDOWN); inner_p.flags.insert(Flags::SHUTDOWN);
self.poll(cx) self.poll(cx)
} }
// disconnect if shutdown // disconnect if shutdown
else if inner.flags.contains(Flags::SHUTDOWN) { else if inner_p.flags.contains(Flags::SHUTDOWN) {
self.poll(cx) self.poll(cx)
} else { } else {
Poll::Pending Poll::Pending

View file

@ -36,7 +36,7 @@ where
impl<T, B> Future for SendResponse<T, B> impl<T, B> Future for SendResponse<T, B>
where where
T: AsyncRead + AsyncWrite, T: AsyncRead + AsyncWrite,
B: MessageBody, B: MessageBody + Unpin,
{ {
type Output = Result<Framed<T, Codec>, Error>; type Output = Result<Framed<T, Codec>, Error>;