1
0
Fork 0
mirror of https://github.com/actix/actix-web.git synced 2024-06-02 13:29:24 +00:00

refactor dispatcher to avoid possible UB with DispatcherState Pin

This commit is contained in:
Maksym Vorobiov 2020-02-10 13:17:38 +02:00 committed by Yuki Okushi
parent 69dab0063c
commit c05f9475c5

View file

@ -71,7 +71,6 @@ where
{
Normal(#[pin] InnerDispatcher<T, S, B, X, U>),
Upgrade(#[pin] U::Future),
None,
}
#[pin_project]
@ -101,7 +100,7 @@ where
ka_expire: Instant,
ka_timer: Option<Delay>,
io: T,
io: Option<T>,
read_buf: BytesMut,
write_buf: BytesMut,
codec: Codec,
@ -148,22 +147,6 @@ where
}
}
}
impl<T, S, B, X, U> DispatcherState<T, S, B, X, U>
where
S: Service<Request = Request>,
S::Error: Into<Error>,
B: MessageBody,
X: Service<Request = Request, Response = Request>,
X::Error: Into<Error>,
U: Service<Request = (Request, Framed<T, Codec>), Response = ()>,
U::Error: fmt::Display,
{
fn take(self: Pin<&mut Self>) -> Self {
std::mem::replace(unsafe { self.get_unchecked_mut() }, Self::None)
}
}
enum PollResponse {
Upgrade(Request),
DoNothing,
@ -258,7 +241,7 @@ where
state: State::None,
error: None,
messages: VecDeque::new(),
io,
io: Some(io),
codec,
read_buf,
service,
@ -322,9 +305,10 @@ where
let len = self.write_buf.len();
let mut written = 0;
#[project]
let InnerDispatcher { mut io, write_buf, .. } = self.project();
let InnerDispatcher { io, write_buf, .. } = self.project();
let mut io = Pin::new(io.as_mut().unwrap());
while written < len {
match Pin::new(&mut io).poll_write(cx, &write_buf[written..])
match io.as_mut().poll_write(cx, &write_buf[written..])
{
Poll::Ready(Ok(0)) => {
return Err(DispatchError::Io(io::Error::new(
@ -751,10 +735,10 @@ where
} else {
// flush buffer
inner.as_mut().poll_flush(cx)?;
if !inner.write_buf.is_empty() {
if !inner.write_buf.is_empty() || inner.io.is_none() {
Poll::Pending
} else {
match Pin::new(inner.project().io).poll_shutdown(cx) {
match Pin::new(inner.project().io).as_pin_mut().unwrap().poll_shutdown(cx) {
Poll::Ready(res) => {
Poll::Ready(res.map_err(DispatchError::from))
}
@ -767,7 +751,7 @@ where
let should_disconnect =
if !inner.flags.contains(Flags::READ_DISCONNECT) {
let mut inner_p = inner.as_mut().project();
read_available(cx, &mut inner_p.io, &mut inner_p.read_buf)?
read_available(cx, inner_p.io.as_mut().unwrap(), &mut inner_p.read_buf)?
} else {
None
};
@ -793,20 +777,17 @@ where
// switch to upgrade handler
if let PollResponse::Upgrade(req) = result {
if let DispatcherState::Normal(inner) = self.as_mut().project().inner.take() {
let mut parts = FramedParts::with_read_buf(
inner.io,
inner.codec,
inner.read_buf,
);
parts.write_buf = inner.write_buf;
let framed = Framed::from_parts(parts);
let upgrade = inner.upgrade.unwrap().call((req, framed));
self.as_mut().project().inner.set(DispatcherState::Upgrade(upgrade));
return self.poll(cx);
} else {
panic!()
}
let inner_p = inner.as_mut().project();
let mut parts = FramedParts::with_read_buf(
inner_p.io.take().unwrap(),
std::mem::take(inner_p.codec),
std::mem::take(inner_p.read_buf),
);
parts.write_buf = std::mem::take(inner_p.write_buf);
let framed = Framed::from_parts(parts);
let upgrade = inner_p.upgrade.take().unwrap().call((req, framed));
self.as_mut().project().inner.set(DispatcherState::Upgrade(upgrade));
return self.poll(cx);
}
// we didnt get WouldBlock from write operation,
@ -859,7 +840,6 @@ where
DispatchError::Upgrade
})
}
DispatcherState::None => panic!(),
}
}
}
@ -949,9 +929,9 @@ mod tests {
Poll::Ready(res) => assert!(res.is_err()),
}
if let DispatcherState::Normal(ref inner) = h1.inner {
if let DispatcherState::Normal(ref mut inner) = h1.inner {
assert!(inner.flags.contains(Flags::READ_DISCONNECT));
assert_eq!(&inner.io.write_buf[..26], b"HTTP/1.1 400 Bad Request\r\n");
assert_eq!(&inner.io.take().unwrap().write_buf[..26], b"HTTP/1.1 400 Bad Request\r\n");
}
})
.await;