use std::{ future::Future, pin::Pin, task::{Context, Poll}, }; use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_service::{IntoService, Service}; use pin_project_lite::pin_project; use super::{Codec, Frame, Message}; pin_project! { pub struct Dispatcher where S: Service, S: 'static, T: AsyncRead, T: AsyncWrite, { #[pin] inner: inner::Dispatcher, } } impl Dispatcher where T: AsyncRead + AsyncWrite, S: Service, S::Future: 'static, S::Error: 'static, { pub fn new>(io: T, service: F) -> Self { Dispatcher { inner: inner::Dispatcher::new(Framed::new(io, Codec::new()), service), } } pub fn with>(framed: Framed, service: F) -> Self { Dispatcher { inner: inner::Dispatcher::new(framed, service), } } } impl Future for Dispatcher where T: AsyncRead + AsyncWrite, S: Service, S::Future: 'static, S::Error: 'static, { type Output = Result<(), inner::DispatcherError>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { self.project().inner.poll(cx) } } /// Framed dispatcher service and related utilities. mod inner { // allow dead code since this mod was ripped from actix-utils #![allow(dead_code)] use core::{ fmt, future::Future, mem, pin::Pin, task::{Context, Poll}, }; use actix_codec::Framed; use actix_service::{IntoService, Service}; use futures_core::stream::Stream; use local_channel::mpsc; use pin_project_lite::pin_project; use tokio::io::{AsyncRead, AsyncWrite}; use tokio_util::codec::{Decoder, Encoder}; use tracing::debug; use crate::{body::BoxBody, Response}; /// Framed transport errors pub enum DispatcherError where U: Encoder + Decoder, { /// Inner service error. Service(E), /// Frame encoding error. Encoder(>::Error), /// Frame decoding error. Decoder(::Error), } impl From for DispatcherError where U: Encoder + Decoder, { fn from(err: E) -> Self { DispatcherError::Service(err) } } impl fmt::Debug for DispatcherError where E: fmt::Debug, U: Encoder + Decoder, >::Error: fmt::Debug, ::Error: fmt::Debug, { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { match *self { DispatcherError::Service(ref e) => { write!(fmt, "DispatcherError::Service({:?})", e) } DispatcherError::Encoder(ref e) => { write!(fmt, "DispatcherError::Encoder({:?})", e) } DispatcherError::Decoder(ref e) => { write!(fmt, "DispatcherError::Decoder({:?})", e) } } } } impl fmt::Display for DispatcherError where E: fmt::Display, U: Encoder + Decoder, >::Error: fmt::Debug, ::Error: fmt::Debug, { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { match *self { DispatcherError::Service(ref e) => write!(fmt, "{}", e), DispatcherError::Encoder(ref e) => write!(fmt, "{:?}", e), DispatcherError::Decoder(ref e) => write!(fmt, "{:?}", e), } } } impl From> for Response where E: fmt::Debug + fmt::Display, U: Encoder + Decoder, >::Error: fmt::Debug, ::Error: fmt::Debug, { fn from(err: DispatcherError) -> Self { Response::internal_server_error().set_body(BoxBody::new(err.to_string())) } } /// Message type wrapper for signalling end of message stream. pub enum Message { /// Message item. Item(T), /// Signal from service to flush all messages and stop processing. Close, } pin_project! { /// A future that reads frames from a [`Framed`] object and passes them to a [`Service`]. pub struct Dispatcher where S: Service<::Item, Response = I>, S::Error: 'static, S::Future: 'static, T: AsyncRead, T: AsyncWrite, U: Encoder, U: Decoder, I: 'static, >::Error: fmt::Debug, { service: S, state: State, #[pin] framed: Framed, rx: mpsc::Receiver, S::Error>>, tx: mpsc::Sender, S::Error>>, } } enum State where S: Service<::Item>, U: Encoder + Decoder, { Processing, Error(DispatcherError), FramedError(DispatcherError), FlushAndStop, Stopping, } impl State where S: Service<::Item>, U: Encoder + Decoder, { fn take_error(&mut self) -> DispatcherError { match mem::replace(self, State::Processing) { State::Error(err) => err, _ => panic!(), } } fn take_framed_error(&mut self) -> DispatcherError { match mem::replace(self, State::Processing) { State::FramedError(err) => err, _ => panic!(), } } } impl Dispatcher where S: Service<::Item, Response = I>, S::Error: 'static, S::Future: 'static, T: AsyncRead + AsyncWrite, U: Decoder + Encoder, I: 'static, ::Error: fmt::Debug, >::Error: fmt::Debug, { /// Create new `Dispatcher`. pub fn new(framed: Framed, service: F) -> Self where F: IntoService::Item>, { let (tx, rx) = mpsc::channel(); Dispatcher { framed, rx, tx, service: service.into_service(), state: State::Processing, } } /// Construct new `Dispatcher` instance with customer `mpsc::Receiver` pub fn with_rx( framed: Framed, service: F, rx: mpsc::Receiver, S::Error>>, ) -> Self where F: IntoService::Item>, { let tx = rx.sender(); Dispatcher { framed, rx, tx, service: service.into_service(), state: State::Processing, } } /// Get sender handle. pub fn tx(&self) -> mpsc::Sender, S::Error>> { self.tx.clone() } /// Get reference to a service wrapped by `Dispatcher` instance. pub fn service(&self) -> &S { &self.service } /// Get mutable reference to a service wrapped by `Dispatcher` instance. pub fn service_mut(&mut self) -> &mut S { &mut self.service } /// Get reference to a framed instance wrapped by `Dispatcher` instance. pub fn framed(&self) -> &Framed { &self.framed } /// Get mutable reference to a framed instance wrapped by `Dispatcher` instance. pub fn framed_mut(&mut self) -> &mut Framed { &mut self.framed } /// Read from framed object. fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> bool where S: Service<::Item, Response = I>, S::Error: 'static, S::Future: 'static, T: AsyncRead + AsyncWrite, U: Decoder + Encoder, I: 'static, >::Error: fmt::Debug, { loop { let this = self.as_mut().project(); match this.service.poll_ready(cx) { Poll::Ready(Ok(_)) => { let item = match this.framed.next_item(cx) { Poll::Ready(Some(Ok(el))) => el, Poll::Ready(Some(Err(err))) => { *this.state = State::FramedError(DispatcherError::Decoder(err)); return true; } Poll::Pending => return false, Poll::Ready(None) => { *this.state = State::Stopping; return true; } }; let tx = this.tx.clone(); let fut = this.service.call(item); actix_rt::spawn(async move { let item = fut.await; let _ = tx.send(item.map(Message::Item)); }); } Poll::Pending => return false, Poll::Ready(Err(err)) => { *this.state = State::Error(DispatcherError::Service(err)); return true; } } } } /// Write to framed object. fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> bool where S: Service<::Item, Response = I>, S::Error: 'static, S::Future: 'static, T: AsyncRead + AsyncWrite, U: Decoder + Encoder, I: 'static, >::Error: fmt::Debug, { loop { let mut this = self.as_mut().project(); while !this.framed.is_write_buf_full() { match Pin::new(&mut this.rx).poll_next(cx) { Poll::Ready(Some(Ok(Message::Item(msg)))) => { if let Err(err) = this.framed.as_mut().write(msg) { *this.state = State::FramedError(DispatcherError::Encoder(err)); return true; } } Poll::Ready(Some(Ok(Message::Close))) => { *this.state = State::FlushAndStop; return true; } Poll::Ready(Some(Err(err))) => { *this.state = State::Error(DispatcherError::Service(err)); return true; } Poll::Ready(None) | Poll::Pending => break, } } if !this.framed.is_write_buf_empty() { match this.framed.flush(cx) { Poll::Pending => break, Poll::Ready(Ok(_)) => {} Poll::Ready(Err(err)) => { debug!("Error sending data: {:?}", err); *this.state = State::FramedError(DispatcherError::Encoder(err)); return true; } } } else { break; } } false } } impl Future for Dispatcher where S: Service<::Item, Response = I>, S::Error: 'static, S::Future: 'static, T: AsyncRead + AsyncWrite, U: Decoder + Encoder, I: 'static, >::Error: fmt::Debug, ::Error: fmt::Debug, { type Output = Result<(), DispatcherError>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { loop { let this = self.as_mut().project(); return match this.state { State::Processing => { if self.as_mut().poll_read(cx) || self.as_mut().poll_write(cx) { continue; } else { Poll::Pending } } State::Error(_) => { // flush write buffer if !this.framed.is_write_buf_empty() && this.framed.flush(cx).is_pending() { return Poll::Pending; } Poll::Ready(Err(this.state.take_error())) } State::FlushAndStop => { if !this.framed.is_write_buf_empty() { this.framed.flush(cx).map(|res| { if let Err(err) = res { debug!("Error sending data: {:?}", err); } Ok(()) }) } else { Poll::Ready(Ok(())) } } State::FramedError(_) => Poll::Ready(Err(this.state.take_framed_error())), State::Stopping => Poll::Ready(Ok(())), }; } } } }