diff --git a/actix-http/src/h1/dispatcher.rs b/actix-http/src/h1/dispatcher.rs index 6e226a30d..043271cb5 100644 --- a/actix-http/src/h1/dispatcher.rs +++ b/actix-http/src/h1/dispatcher.rs @@ -71,7 +71,6 @@ where { Normal(#[pin] InnerDispatcher), Upgrade(#[pin] U::Future), - None, } #[pin_project] @@ -101,7 +100,7 @@ where ka_expire: Instant, ka_timer: Option, - io: T, + io: Option, read_buf: BytesMut, write_buf: BytesMut, codec: Codec, @@ -148,22 +147,6 @@ where } } } - -impl DispatcherState -where - S: Service, - S::Error: Into, - B: MessageBody, - X: Service, - X::Error: Into, - U: Service), Response = ()>, - U::Error: fmt::Display, -{ - fn take(self: Pin<&mut Self>) -> Self { - std::mem::replace(unsafe { self.get_unchecked_mut() }, Self::None) - } -} - enum PollResponse { Upgrade(Request), DoNothing, @@ -258,7 +241,7 @@ where state: State::None, error: None, messages: VecDeque::new(), - io, + io: Some(io), codec, read_buf, service, @@ -322,9 +305,10 @@ where let len = self.write_buf.len(); let mut written = 0; #[project] - let InnerDispatcher { mut io, write_buf, .. } = self.project(); + let InnerDispatcher { io, write_buf, .. } = self.project(); + let mut io = Pin::new(io.as_mut().unwrap()); while written < len { - match Pin::new(&mut io).poll_write(cx, &write_buf[written..]) + match io.as_mut().poll_write(cx, &write_buf[written..]) { Poll::Ready(Ok(0)) => { return Err(DispatchError::Io(io::Error::new( @@ -751,10 +735,10 @@ where } else { // flush buffer inner.as_mut().poll_flush(cx)?; - if !inner.write_buf.is_empty() { + if !inner.write_buf.is_empty() || inner.io.is_none() { Poll::Pending } else { - match Pin::new(inner.project().io).poll_shutdown(cx) { + match Pin::new(inner.project().io).as_pin_mut().unwrap().poll_shutdown(cx) { Poll::Ready(res) => { Poll::Ready(res.map_err(DispatchError::from)) } @@ -767,7 +751,7 @@ where let should_disconnect = if !inner.flags.contains(Flags::READ_DISCONNECT) { let mut inner_p = inner.as_mut().project(); - read_available(cx, &mut inner_p.io, &mut inner_p.read_buf)? + read_available(cx, inner_p.io.as_mut().unwrap(), &mut inner_p.read_buf)? } else { None }; @@ -793,20 +777,17 @@ where // switch to upgrade handler if let PollResponse::Upgrade(req) = result { - if let DispatcherState::Normal(inner) = self.as_mut().project().inner.take() { - let mut parts = FramedParts::with_read_buf( - inner.io, - inner.codec, - inner.read_buf, - ); - parts.write_buf = inner.write_buf; - let framed = Framed::from_parts(parts); - let upgrade = inner.upgrade.unwrap().call((req, framed)); - self.as_mut().project().inner.set(DispatcherState::Upgrade(upgrade)); - return self.poll(cx); - } else { - panic!() - } + let inner_p = inner.as_mut().project(); + let mut parts = FramedParts::with_read_buf( + inner_p.io.take().unwrap(), + std::mem::take(inner_p.codec), + std::mem::take(inner_p.read_buf), + ); + parts.write_buf = std::mem::take(inner_p.write_buf); + let framed = Framed::from_parts(parts); + let upgrade = inner_p.upgrade.take().unwrap().call((req, framed)); + self.as_mut().project().inner.set(DispatcherState::Upgrade(upgrade)); + return self.poll(cx); } // we didnt get WouldBlock from write operation, @@ -859,7 +840,6 @@ where DispatchError::Upgrade }) } - DispatcherState::None => panic!(), } } } @@ -949,9 +929,9 @@ mod tests { Poll::Ready(res) => assert!(res.is_err()), } - if let DispatcherState::Normal(ref inner) = h1.inner { + if let DispatcherState::Normal(ref mut inner) = h1.inner { assert!(inner.flags.contains(Flags::READ_DISCONNECT)); - assert_eq!(&inner.io.write_buf[..26], b"HTTP/1.1 400 Bad Request\r\n"); + assert_eq!(&inner.io.take().unwrap().write_buf[..26], b"HTTP/1.1 400 Bad Request\r\n"); } }) .await;