From d7d3d663e9e9665a7d162231df71818fbb6f0eb2 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Sat, 4 Nov 2017 12:33:14 -0700 Subject: [PATCH] refactor server impl and add support for alpn http2 negotiation --- Cargo.toml | 12 ++- README.md | 19 ++++- examples/tls/Cargo.toml | 2 +- examples/tls/src/main.rs | 2 +- src/application.rs | 2 +- src/channel.rs | 98 +++++++++++++++++++++++++ src/h1.rs | 14 ++-- src/h1writer.rs | 21 +++--- src/h2.rs | 2 +- src/h2writer.rs | 32 ++++---- src/lib.rs | 10 +++ src/recognizer.rs | 2 +- src/server.rs | 154 ++++++++++++++++++--------------------- src/ws.rs | 6 +- 14 files changed, 240 insertions(+), 136 deletions(-) create mode 100644 src/channel.rs diff --git a/Cargo.toml b/Cargo.toml index 58dd2a3fe..41a3568bb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,6 +28,9 @@ default = [] # tls tls = ["native-tls", "tokio-tls"] +# openssl +alpn = ["openssl", "openssl/v102", "openssl/v110", "tokio-openssl"] + [dependencies] log = "0.3" time = "0.1" @@ -50,10 +53,13 @@ tokio-core = "0.1" h2 = { git = 'https://github.com/carllerche/h2' } -# tls +# native-tls native-tls = { version="0.1", optional = true } tokio-tls = { version="0.1", optional = true } +# openssl +tokio-openssl = { version="0.1", optional = true } + [dependencies.actix] version = ">=0.3.1" #path = "../actix" @@ -61,6 +67,10 @@ version = ">=0.3.1" default-features = false features = [] +[dependencies.openssl] +version = "0.9" +optional = true + [dev-dependencies] env_logger = "0.4" reqwest = "0.8" diff --git a/README.md b/README.md index 543c30f9f..affb1b881 100644 --- a/README.md +++ b/README.md @@ -13,14 +13,25 @@ Actix web is licensed under the [Apache-2.0 license](http://opensource.org/licen ## Features - * HTTP/1 and HTTP/2 support - * Streaming and pipelining support - * Keep-alive and slow requests support - * [WebSockets support](https://actix.github.io/actix-web/actix_web/ws/index.html) + * HTTP/1 and HTTP/2 + * Streaming and pipelining + * Keep-alive and slow requests handling + * [WebSockets](https://actix.github.io/actix-web/actix_web/ws/index.html) * Configurable request routing * Multipart streams * Middlewares +## HTTP/2 Negotiation + +To use http/2 protocol over tls without prior knowlage requires +[tls alpn]( (https://tools.ietf.org/html/rfc7301). At the moment only +rust-openssl supports alpn. + +```toml +[dependencies] +actix-web = { git = "https://github.com/actix/actix-web", features=["alpn"] } +``` + ## Usage To use `actix-web`, add this to your `Cargo.toml`: diff --git a/examples/tls/Cargo.toml b/examples/tls/Cargo.toml index bb983dc49..1c73ab396 100644 --- a/examples/tls/Cargo.toml +++ b/examples/tls/Cargo.toml @@ -11,4 +11,4 @@ path = "src/main.rs" env_logger = "0.4" actix = "0.3.1" -actix-web = { path = "../../", features=["tls"] } +actix-web = { path = "../../", features=["alpn"] } diff --git a/examples/tls/src/main.rs b/examples/tls/src/main.rs index 5e2d37544..49d8acbb2 100644 --- a/examples/tls/src/main.rs +++ b/examples/tls/src/main.rs @@ -26,7 +26,7 @@ fn main() { let mut file = File::open("identity.pfx").unwrap(); let mut pkcs12 = vec![]; file.read_to_end(&mut pkcs12).unwrap(); - let pkcs12 = Pkcs12::from_der(&pkcs12, "12345").unwrap(); + let pkcs12 = Pkcs12::from_der(&pkcs12).unwrap().parse("12345").unwrap(); HttpServer::new( Application::default("/") diff --git a/src/application.rs b/src/application.rs index f3acb8697..a2badbd50 100644 --- a/src/application.rs +++ b/src/application.rs @@ -9,7 +9,7 @@ use resource::Resource; use recognizer::{RouteRecognizer, check_pattern}; use httprequest::HttpRequest; use httpresponse::HttpResponse; -use server::HttpHandler; +use channel::HttpHandler; /// Middleware definition diff --git a/src/channel.rs b/src/channel.rs new file mode 100644 index 000000000..3c265c025 --- /dev/null +++ b/src/channel.rs @@ -0,0 +1,98 @@ +use std::rc::Rc; + +use actix::dev::*; +use bytes::Bytes; +use futures::{Future, Poll, Async}; +use tokio_io::{AsyncRead, AsyncWrite}; + +use h1; +use h2; +use task::Task; +use payload::Payload; +use httprequest::HttpRequest; + +/// Low level http request handler +pub trait HttpHandler: 'static { + /// Http handler prefix + fn prefix(&self) -> &str; + /// Handle request + fn handle(&self, req: &mut HttpRequest, payload: Payload) -> Task; +} + +enum HttpProtocol + where T: AsyncRead + AsyncWrite + 'static, A: 'static, H: 'static +{ + H1(h1::Http1), + H2(h2::Http2), +} + +pub struct HttpChannel + where T: AsyncRead + AsyncWrite + 'static, A: 'static, H: 'static +{ + proto: Option>, +} + +impl HttpChannel + where T: AsyncRead + AsyncWrite + 'static, A: 'static, H: HttpHandler + 'static +{ + pub fn new(stream: T, addr: A, router: Rc>, http2: bool) -> HttpChannel { + if http2 { + HttpChannel { + proto: Some(HttpProtocol::H2( + h2::Http2::new(stream, addr, router, Bytes::new()))) } + } else { + HttpChannel { + proto: Some(HttpProtocol::H1( + h1::Http1::new(stream, addr, router))) } + } + } +} + +/*impl Drop for HttpChannel { + fn drop(&mut self) { + println!("Drop http channel"); + } +}*/ + +impl Actor for HttpChannel + where T: AsyncRead + AsyncWrite + 'static, A: 'static, H: HttpHandler + 'static +{ + type Context = Context; +} + +impl Future for HttpChannel + where T: AsyncRead + AsyncWrite + 'static, A: 'static, H: HttpHandler + 'static +{ + type Item = (); + type Error = (); + + fn poll(&mut self) -> Poll { + match self.proto { + Some(HttpProtocol::H1(ref mut h1)) => { + match h1.poll() { + Ok(Async::Ready(h1::Http1Result::Done)) => + return Ok(Async::Ready(())), + Ok(Async::Ready(h1::Http1Result::Upgrade)) => (), + Ok(Async::NotReady) => + return Ok(Async::NotReady), + Err(_) => + return Err(()), + } + } + Some(HttpProtocol::H2(ref mut h2)) => + return h2.poll(), + None => unreachable!(), + } + + // upgrade to h2 + let proto = self.proto.take().unwrap(); + match proto { + HttpProtocol::H1(h1) => { + let (stream, addr, router, buf) = h1.into_inner(); + self.proto = Some(HttpProtocol::H2(h2::Http2::new(stream, addr, router, buf))); + self.poll() + } + _ => unreachable!() + } + } +} diff --git a/src/h1.rs b/src/h1.rs index 78ad3ad32..cb2a156d9 100644 --- a/src/h1.rs +++ b/src/h1.rs @@ -15,7 +15,7 @@ use tokio_core::reactor::Timeout; use percent_encoding; use task::Task; -use server::HttpHandler; +use channel::HttpHandler; use error::ParseError; use httpcodes::HTTPNotFound; use httprequest::HttpRequest; @@ -75,7 +75,7 @@ impl Http1 } pub fn into_inner(mut self) -> (T, A, Rc>, Bytes) { - (self.stream.into_inner(), self.addr, self.router, self.read_buf.freeze()) + (self.stream.unwrap(), self.addr, self.router, self.read_buf.freeze()) } pub fn poll(&mut self) -> Poll { @@ -114,7 +114,7 @@ impl Http1 if self.keepalive { self.keepalive = self.stream.keepalive(); } - self.stream = H1Writer::new(self.stream.into_inner()); + self.stream = H1Writer::new(self.stream.unwrap()); item.eof = true; if ready { @@ -251,12 +251,12 @@ impl Http1 // check for parse error if self.tasks.is_empty() { + if self.h2 { + return Ok(Async::Ready(Http1Result::Upgrade)) + } if self.error || self.keepalive_timer.is_none() { return Ok(Async::Ready(Http1Result::Done)) } - else if self.h2 { - return Ok(Async::Ready(Http1Result::Upgrade)) - } } if not_ready { @@ -482,7 +482,7 @@ impl Reader { if buf.is_empty() { return Ok(Message::NotReady); } - if buf.len() >= 14 && &buf[..14] == &HTTP2_PREFACE[..] { + if buf.len() >= 14 && buf[..14] == HTTP2_PREFACE[..] { return Ok(Message::Http2) } diff --git a/src/h1writer.rs b/src/h1writer.rs index 0aa7b94d3..7a7d08001 100644 --- a/src/h1writer.rs +++ b/src/h1writer.rs @@ -63,7 +63,7 @@ impl H1Writer { self.stream.as_mut().unwrap() } - pub fn into_inner(&mut self) -> T { + pub fn unwrap(&mut self) -> T { self.stream.take().unwrap() } @@ -90,12 +90,11 @@ impl H1Writer { return Ok(WriterState::Done) } } - Err(err) => - return Err(err), + Err(err) => return Err(err), } } } - return Ok(WriterState::Done) + Ok(WriterState::Done) } } @@ -225,9 +224,9 @@ impl Writer for H1Writer { } if self.buffer.len() > MAX_WRITE_BUFFER_SIZE { - return Ok(WriterState::Pause) + Ok(WriterState::Pause) } else { - return Ok(WriterState::Done) + Ok(WriterState::Done) } } @@ -236,12 +235,10 @@ impl Writer for H1Writer { //debug!("last payload item, but it is not EOF "); Err(io::Error::new(io::ErrorKind::Other, "Last payload item, but eof is not reached")) + } else if self.buffer.len() > MAX_WRITE_BUFFER_SIZE { + Ok(WriterState::Pause) } else { - if self.buffer.len() > MAX_WRITE_BUFFER_SIZE { - return Ok(WriterState::Pause) - } else { - return Ok(WriterState::Done) - } + Ok(WriterState::Done) } } @@ -345,7 +342,7 @@ impl Encoder { true }, Kind::Length(ref mut remaining) => { - return *remaining == 0 + *remaining == 0 }, } } diff --git a/src/h2.rs b/src/h2.rs index 9c89c18e5..3eb2c0bf6 100644 --- a/src/h2.rs +++ b/src/h2.rs @@ -12,7 +12,7 @@ use futures::{Async, Poll, Future, Stream}; use tokio_io::{AsyncRead, AsyncWrite}; use task::Task; -use server::HttpHandler; +use channel::HttpHandler; use httpcodes::HTTPNotFound; use httprequest::HttpRequest; use payload::{Payload, PayloadError, PayloadSender}; diff --git a/src/h2writer.rs b/src/h2writer.rs index ae3c9bc82..d97f0a542 100644 --- a/src/h2writer.rs +++ b/src/h2writer.rs @@ -77,15 +77,13 @@ impl H2Writer { let bytes = self.buffer.split_to(cmp::min(cap, len)); let eof = self.buffer.is_empty() && self.eof; - if let Err(_) = stream.send_data(bytes.freeze(), eof) { - return Err(io::Error::new(io::ErrorKind::Other, "")) + if let Err(err) = stream.send_data(bytes.freeze(), eof) { + return Err(io::Error::new(io::ErrorKind::Other, err)) + } else if !self.buffer.is_empty() { + let cap = cmp::min(self.buffer.len(), CHUNK_SIZE); + stream.reserve_capacity(cap); } else { - if !self.buffer.is_empty() { - let cap = cmp::min(self.buffer.len(), CHUNK_SIZE); - stream.reserve_capacity(cap); - } else { - return Ok(WriterState::Done) - } + return Ok(WriterState::Done) } } Err(_) => { @@ -94,7 +92,7 @@ impl H2Writer { } } } - return Ok(WriterState::Done) + Ok(WriterState::Done) } } @@ -200,9 +198,9 @@ impl Writer for H2Writer { } if self.buffer.len() > MAX_WRITE_BUFFER_SIZE { - return Ok(WriterState::Pause) + Ok(WriterState::Pause) } else { - return Ok(WriterState::Done) + Ok(WriterState::Done) } } @@ -211,12 +209,10 @@ impl Writer for H2Writer { if !self.encoder.encode_eof(&mut self.buffer) { Err(io::Error::new(io::ErrorKind::Other, "Last payload item, but eof is not reached")) + } else if self.buffer.len() > MAX_WRITE_BUFFER_SIZE { + Ok(WriterState::Pause) } else { - if self.buffer.len() > MAX_WRITE_BUFFER_SIZE { - return Ok(WriterState::Pause) - } else { - return Ok(WriterState::Done) - } + Ok(WriterState::Done) } } @@ -288,9 +284,7 @@ impl Encoder { pub fn encode_eof(&mut self, _dst: &mut BytesMut) -> bool { match self.kind { Kind::Eof => true, - Kind::Length(ref mut remaining) => { - return *remaining == 0 - }, + Kind::Length(ref mut remaining) => *remaining == 0 } } } diff --git a/src/lib.rs b/src/lib.rs index d42004a4b..a4f875998 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -27,6 +27,11 @@ extern crate native_tls; #[cfg(feature="tls")] extern crate tokio_tls; +#[cfg(feature="openssl")] +extern crate openssl; +#[cfg(feature="openssl")] +extern crate tokio_openssl; + mod application; mod body; mod context; @@ -42,6 +47,7 @@ mod route; mod task; mod staticfiles; mod server; +mod channel; mod wsframe; mod wsproto; mod h1; @@ -65,6 +71,7 @@ pub use recognizer::{Params, RouteRecognizer}; pub use logger::Logger; pub use server::HttpServer; pub use context::HttpContext; +pub use channel::HttpChannel; pub use staticfiles::StaticFiles; // re-exports @@ -75,3 +82,6 @@ pub use http_range::{HttpRange, HttpRangeParseError}; #[cfg(feature="tls")] pub use native_tls::Pkcs12; + +#[cfg(feature="openssl")] +pub use openssl::pkcs12::Pkcs12; diff --git a/src/recognizer.rs b/src/recognizer.rs index 312c2b375..71acfd4ac 100644 --- a/src/recognizer.rs +++ b/src/recognizer.rs @@ -116,7 +116,7 @@ pub(crate) fn check_pattern(path: &str) { } fn parse(pattern: &str) -> String { - const DEFAULT_PATTERN: &'static str = "[^/]+"; + const DEFAULT_PATTERN: &str = "[^/]+"; let mut re = String::from("^/"); let mut in_param = false; diff --git a/src/server.rs b/src/server.rs index c7cd6613f..6fb1f23ef 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,30 +1,26 @@ -use std::{io, net, mem}; +use std::{io, net}; use std::rc::Rc; use std::marker::PhantomData; use actix::dev::*; -use futures::{Future, Poll, Async, Stream}; -use tokio_core::net::{TcpListener, TcpStream}; +use futures::Stream; use tokio_io::{AsyncRead, AsyncWrite}; +use tokio_core::net::{TcpListener, TcpStream}; #[cfg(feature="tls")] use native_tls::TlsAcceptor; #[cfg(feature="tls")] use tokio_tls::{TlsStream, TlsAcceptorExt}; -use h1; -use h2; -use task::Task; -use payload::Payload; -use httprequest::HttpRequest; +#[cfg(feature="alpn")] +use openssl::ssl::{SslMethod, SslAcceptorBuilder}; +#[cfg(feature="alpn")] +use openssl::pkcs12::ParsedPkcs12; +#[cfg(feature="alpn")] +use tokio_openssl::{SslStream, SslAcceptorExt}; + +use channel::{HttpChannel, HttpHandler}; -/// Low level http request handler -pub trait HttpHandler: 'static { - /// Http handler prefix - fn prefix(&self) -> &str; - /// Handle request - fn handle(&self, req: &mut HttpRequest, payload: Payload) -> Task; -} /// An HTTP Server /// @@ -66,7 +62,7 @@ impl HttpServer S: Stream + 'static { Ok(HttpServer::create(move |ctx| { - ctx.add_stream(stream.map(|(t, a)| IoStream(t, a))); + ctx.add_stream(stream.map(|(t, a)| IoStream(t, a, false))); self })) } @@ -111,7 +107,7 @@ impl HttpServer { Ok(HttpServer::create(move |ctx| { for (addr, tcp) in addrs { info!("Starting http server on {}", addr); - ctx.add_stream(tcp.incoming().map(|(t, a)| IoStream(t, a))); + ctx.add_stream(tcp.incoming().map(|(t, a)| IoStream(t, a, false))); } self })) @@ -161,7 +157,61 @@ impl HttpServer, net::SocketAddr, H> { } } -struct IoStream(T, A); +#[cfg(feature="alpn")] +impl HttpServer, net::SocketAddr, H> { + + /// Start listening for incomming tls connections. + /// + /// This methods converts address to list of `SocketAddr` + /// then binds to all available addresses. + pub fn serve_tls(self, addr: S, identity: ParsedPkcs12) -> io::Result + where Self: ActorAddress, + S: net::ToSocketAddrs, + { + let addrs = self.bind(addr)?; + let acceptor = match SslAcceptorBuilder::mozilla_intermediate(SslMethod::tls(), + &identity.pkey, + &identity.cert, + &identity.chain) + { + Ok(mut builder) => { + match builder.builder_mut().set_alpn_protocols(&[b"h2", b"http/1.1"]) { + Ok(_) => builder.build(), + Err(err) => return Err(io::Error::new(io::ErrorKind::Other, err)), + } + }, + Err(err) => return Err(io::Error::new(io::ErrorKind::Other, err)) + }; + + Ok(HttpServer::create(move |ctx| { + for (addr, tcp) in addrs { + info!("Starting tls http server on {}", addr); + + let acc = acceptor.clone(); + ctx.add_stream(tcp.incoming().and_then(move |(stream, addr)| { + SslAcceptorExt::accept_async(&acc, stream) + .map(move |stream| { + let http2 = if let Some(p) = + stream.get_ref().ssl().selected_alpn_protocol() + { + p.len() == 2 && &p == b"h2" + } else { + false + }; + IoStream(stream, addr, http2) + }) + .map_err(|err| { + trace!("Error during handling tls connection: {}", err); + io::Error::new(io::ErrorKind::Other, err) + }) + })); + } + self + })) + } +} + +struct IoStream(T, A, bool); impl ResponseType for IoStream where T: AsyncRead + AsyncWrite + 'static, @@ -189,73 +239,7 @@ impl Handler, io::Error> for HttpServer -> Response> { Arbiter::handle().spawn( - HttpChannel{ - proto: Protocol::H1(h1::Http1::new(msg.0, msg.1, Rc::clone(&self.h))) - }); + HttpChannel::new(msg.0, msg.1, Rc::clone(&self.h), msg.2)); Self::empty() } } - -enum Protocol - where T: AsyncRead + AsyncWrite + 'static, A: 'static, H: 'static -{ - H1(h1::Http1), - H2(h2::Http2), - None, -} - -pub struct HttpChannel - where T: AsyncRead + AsyncWrite + 'static, A: 'static, H: 'static -{ - proto: Protocol, -} - -/*impl Drop for HttpChannel { - fn drop(&mut self) { - println!("Drop http channel"); - } -}*/ - -impl Actor for HttpChannel - where T: AsyncRead + AsyncWrite + 'static, A: 'static, H: HttpHandler + 'static -{ - type Context = Context; -} - -impl Future for HttpChannel - where T: AsyncRead + AsyncWrite + 'static, A: 'static, H: HttpHandler + 'static -{ - type Item = (); - type Error = (); - - fn poll(&mut self) -> Poll { - match self.proto { - Protocol::H1(ref mut h1) => { - match h1.poll() { - Ok(Async::Ready(h1::Http1Result::Done)) => - return Ok(Async::Ready(())), - Ok(Async::Ready(h1::Http1Result::Upgrade)) => (), - Ok(Async::NotReady) => - return Ok(Async::NotReady), - Err(_) => - return Err(()), - } - } - Protocol::H2(ref mut h2) => - return h2.poll(), - Protocol::None => - unreachable!() - } - - // upgrade to h2 - let proto = mem::replace(&mut self.proto, Protocol::None); - match proto { - Protocol::H1(h1) => { - let (stream, addr, router, buf) = h1.into_inner(); - self.proto = Protocol::H2(h2::Http2::new(stream, addr, router, buf)); - return self.poll() - } - _ => unreachable!() - } - } -} diff --git a/src/ws.rs b/src/ws.rs index 1319418f9..78dc85f99 100644 --- a/src/ws.rs +++ b/src/ws.rs @@ -80,11 +80,11 @@ use wsproto::*; pub use wsproto::CloseCode; #[doc(hidden)] -const SEC_WEBSOCKET_ACCEPT: &'static str = "SEC-WEBSOCKET-ACCEPT"; +const SEC_WEBSOCKET_ACCEPT: &str = "SEC-WEBSOCKET-ACCEPT"; #[doc(hidden)] -const SEC_WEBSOCKET_KEY: &'static str = "SEC-WEBSOCKET-KEY"; +const SEC_WEBSOCKET_KEY: &str = "SEC-WEBSOCKET-KEY"; #[doc(hidden)] -const SEC_WEBSOCKET_VERSION: &'static str = "SEC-WEBSOCKET-VERSION"; +const SEC_WEBSOCKET_VERSION: &str = "SEC-WEBSOCKET-VERSION"; // #[doc(hidden)] // const SEC_WEBSOCKET_PROTOCOL: &'static str = "SEC-WEBSOCKET-PROTOCOL";