use std::collections::VecDeque; use std::io; use std::marker::PhantomData; use std::net::SocketAddr; use futures::{ future::{ok, FutureResult}, Async, Future, Poll, }; use tokio; use tokio_tcp::{ConnectFuture, TcpStream}; use trust_dns_resolver::config::{ResolverConfig, ResolverOpts}; use trust_dns_resolver::lookup_ip::LookupIpFuture; use trust_dns_resolver::system_conf::read_system_conf; use trust_dns_resolver::{AsyncResolver, Background}; use super::{NewService, Service}; pub trait HostAware { fn host(&self) -> &str; } impl HostAware for String { fn host(&self) -> &str { self.as_ref() } } #[derive(Fail, Debug)] pub enum ConnectorError { /// Failed to resolve the hostname #[fail(display = "Failed resolving hostname: {}", _0)] Resolver(String), /// Address is invalid #[fail(display = "Invalid input: {}", _0)] InvalidInput(&'static str), /// Connection io error #[fail(display = "{}", _0)] IoError(io::Error), } pub struct ConnectionInfo { pub host: String, pub addr: SocketAddr, } pub struct Connector { resolver: AsyncResolver, req: PhantomData, } impl Default for Connector { fn default() -> Self { let (cfg, opts) = if let Ok((cfg, opts)) = read_system_conf() { (cfg, opts) } else { (ResolverConfig::default(), ResolverOpts::default()) }; Connector::new(cfg, opts) } } impl Connector { pub fn new(cfg: ResolverConfig, opts: ResolverOpts) -> Self { let (resolver, bg) = AsyncResolver::new(cfg, opts); tokio::spawn(bg); Connector { resolver, req: PhantomData, } } pub fn new_service() -> impl NewService< Request = T, Response = (T, ConnectionInfo, TcpStream), Error = ConnectorError, InitError = E, > + Clone { || -> FutureResult, E> { ok(Connector::default()) } } pub fn new_service_with_config( cfg: ResolverConfig, opts: ResolverOpts, ) -> impl NewService< Request = T, Response = (T, ConnectionInfo, TcpStream), Error = ConnectorError, InitError = E, > + Clone { move || -> FutureResult, E> { ok(Connector::new(cfg.clone(), opts)) } } pub fn change_request(&self) -> Connector { Connector { resolver: self.resolver.clone(), req: PhantomData, } } } impl Clone for Connector { fn clone(&self) -> Self { Connector { resolver: self.resolver.clone(), req: PhantomData, } } } impl Service for Connector { type Request = T; type Response = (T, ConnectionInfo, TcpStream); type Error = ConnectorError; type Future = ConnectorFuture; fn poll_ready(&mut self) -> Poll<(), Self::Error> { Ok(Async::Ready(())) } fn call(&mut self, req: Self::Request) -> Self::Future { let fut = ResolveFut::new(req, 0, &self.resolver); ConnectorFuture { fut, fut2: None } } } pub struct ConnectorFuture { fut: ResolveFut, fut2: Option>, } impl Future for ConnectorFuture { type Item = (T, ConnectionInfo, TcpStream); type Error = ConnectorError; fn poll(&mut self) -> Poll { if let Some(ref mut fut) = self.fut2 { return fut.poll(); } match self.fut.poll()? { Async::Ready((req, host, addrs)) => { self.fut2 = Some(TcpConnector::new(req, host, addrs)); self.poll() } Async::NotReady => Ok(Async::NotReady), } } } /// Resolver future struct ResolveFut { req: Option, host: Option, port: u16, lookup: Option>, addrs: Option>, error: Option, error2: Option, } impl ResolveFut { pub fn new(addr: T, port: u16, resolver: &AsyncResolver) -> Self { // we need to do dns resolution match ResolveFut::::parse(addr.host(), port) { Ok((host, port)) => { let lookup = Some(resolver.lookup_ip(host.as_str())); ResolveFut { port, lookup, req: Some(addr), host: Some(host), addrs: None, error: None, error2: None, } } Err(err) => ResolveFut { port, req: None, host: None, lookup: None, addrs: None, error: Some(err), error2: None, }, } } fn parse(addr: &str, port: u16) -> Result<(String, u16), ConnectorError> { macro_rules! try_opt { ($e:expr, $msg:expr) => { match $e { Some(r) => r, None => return Err(ConnectorError::InvalidInput($msg)), } }; } // split the string by ':' and convert the second part to u16 let mut parts_iter = addr.splitn(2, ':'); let host = try_opt!(parts_iter.next(), "invalid socket address"); let port_str = parts_iter.next().unwrap_or(""); let port: u16 = port_str.parse().unwrap_or(port); Ok((host.to_owned(), port)) } } impl Future for ResolveFut { type Item = (T, String, VecDeque); type Error = ConnectorError; fn poll(&mut self) -> Poll { if let Some(err) = self.error.take() { Err(err) } else if let Some(err) = self.error2.take() { Err(ConnectorError::Resolver(err)) } else if let Some(addrs) = self.addrs.take() { Ok(Async::Ready(( self.req.take().unwrap(), self.host.take().unwrap(), addrs, ))) } else { match self.lookup.as_mut().unwrap().poll() { Ok(Async::NotReady) => Ok(Async::NotReady), Ok(Async::Ready(ips)) => { let addrs: VecDeque<_> = ips .iter() .map(|ip| SocketAddr::new(ip, self.port)) .collect(); if addrs.is_empty() { Err(ConnectorError::Resolver( "Expect at least one A dns record".to_owned(), )) } else { Ok(Async::Ready(( self.req.take().unwrap(), self.host.take().unwrap(), addrs, ))) } } Err(err) => Err(ConnectorError::Resolver(format!("{}", err))), } } } } /// Tcp stream connector pub struct TcpConnector { req: Option, host: Option, addr: Option, addrs: VecDeque, stream: Option, } impl TcpConnector { pub fn new(req: T, host: String, addrs: VecDeque) -> TcpConnector { TcpConnector { addrs, req: Some(req), host: Some(host), addr: None, stream: None, } } } impl Future for TcpConnector { type Item = (T, ConnectionInfo, TcpStream); type Error = ConnectorError; fn poll(&mut self) -> Poll { // connect loop { if let Some(new) = self.stream.as_mut() { match new.poll() { Ok(Async::Ready(sock)) => { return Ok(Async::Ready(( self.req.take().unwrap(), ConnectionInfo { host: self.host.take().unwrap(), addr: self.addr.take().unwrap(), }, sock, ))) } Ok(Async::NotReady) => return Ok(Async::NotReady), Err(err) => { if self.addrs.is_empty() { return Err(ConnectorError::IoError(err)); } } } } // try to connect let addr = self.addrs.pop_front().unwrap(); self.stream = Some(TcpStream::connect(&addr)); self.addr = Some(addr) } } }