diff --git a/.travis.yml b/.travis.yml index b0c4c6e73..640aa1b92 100644 --- a/.travis.yml +++ b/.travis.yml @@ -47,7 +47,7 @@ script: USE_SKEPTIC=1 cargo test --features=alpn else cargo clean - cargo test + cargo test -- --nocapture # --features=alpn fi @@ -55,8 +55,10 @@ script: if [[ "$TRAVIS_RUST_VERSION" == "stable" ]]; then cd examples/basics && cargo check && cd ../.. cd examples/hello-world && cargo check && cd ../.. + cd examples/http-proxy && cargo check && cd ../.. cd examples/multipart && cargo check && cd ../.. cd examples/json && cargo check && cd ../.. + cd examples/juniper && cargo check && cd ../.. cd examples/state && cargo check && cd ../.. cd examples/template_tera && cargo check && cd ../.. cd examples/diesel && cargo check && cd ../.. @@ -64,6 +66,7 @@ script: cd examples/tls && cargo check && cd ../.. cd examples/websocket-chat && cargo check && cd ../.. cd examples/websocket && cargo check && cd ../.. + cd examples/unix-socket && cargo check && cd ../.. fi - | if [[ "$TRAVIS_RUST_VERSION" == "nightly" && $CLIPPY ]]; then @@ -73,7 +76,7 @@ script: # Upload docs after_success: - | - if [[ "$TRAVIS_OS_NAME" == "linux" && "$TRAVIS_PULL_REQUEST" = "false" && "$TRAVIS_BRANCH" == "master" && "$TRAVIS_RUST_VERSION" == "stable" ]]; then + if [[ "$TRAVIS_OS_NAME" == "linux" && "$TRAVIS_PULL_REQUEST" = "false" && "$TRAVIS_BRANCH" == "master" && "$TRAVIS_RUST_VERSION" == "nightly" ]]; then cargo doc --features "alpn, tls" --no-deps && echo "" > target/doc/index.html && cargo install mdbook && diff --git a/CHANGES.md b/CHANGES.md index b623c163f..278e4780e 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,13 +1,60 @@ # Changes -## 0.4.0 (2018-02-..) +## 0.4.5 (2018-03-xx) + +* Enable compression support for `NamedFile` + +* Add `ResponseError` impl for `SendRequestError`. + This improves ergonomics of http client. + + +## 0.4.4 (2018-03-04) + +* Allow to use Arc> as response/request body + +* Fix handling of requests with an encoded body with a length > 8192 #93 + +## 0.4.3 (2018-03-03) + +* Fix request body read bug + +* Fix segmentation fault #79 + +* Set reuse address before bind #90 + + +## 0.4.2 (2018-03-02) + +* Better naming for websockets implementation + +* Add `Pattern::with_prefix()`, make it more usable outside of actix + +* Add csrf middleware for filter for cross-site request forgery #89 + +* Fix disconnect on idle connections + + +## 0.4.1 (2018-03-01) + +* Rename `Route::p()` to `Route::filter()` + +* Better naming for http codes + +* Fix payload parse in situation when socket data is not ready. + +* Fix Session mutable borrow lifetime #87 + + +## 0.4.0 (2018-02-28) * Actix 0.5 compatibility -* Fix request json loader +* Fix request json/urlencoded loaders * Simplify HttpServer type definition +* Added HttpRequest::encoding() method + * Added HttpRequest::mime_type() method * Added HttpRequest::uri_mut(), allows to modify request uri @@ -16,11 +63,11 @@ * Added http client -* Added basic websocket client +* Added websocket client * Added TestServer::ws(), test websockets client -* Added TestServer test http client +* Added TestServer http client support * Allow to override content encoding on application level diff --git a/Cargo.toml b/Cargo.toml index b7999b743..9b69e2560 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,15 +1,17 @@ [package] name = "actix-web" -version = "0.4.0" +version = "0.4.4" authors = ["Nikolay Kim "] -description = "Actix web framework" +description = "Actix web is a small, pragmatic, extremely fast, web framework for Rust." readme = "README.md" keywords = ["http", "web", "framework", "async", "futures"] homepage = "https://github.com/actix/actix-web" repository = "https://github.com/actix/actix-web.git" documentation = "https://docs.rs/actix-web/" categories = ["network-programming", "asynchronous", - "web-programming::http-server", "web-programming::websocket"] + "web-programming::http-server", + "web-programming::http-client", + "web-programming::websocket"] license = "MIT/Apache-2.0" exclude = [".gitignore", ".travis.yml", ".cargo/config", "appveyor.yml", "/examples/**"] @@ -40,7 +42,7 @@ brotli2 = "^0.3.2" failure = "0.1.1" flate2 = "1.0" h2 = "0.1" -http = "^0.1.2" +http = "^0.1.5" httparse = "1.2" http-range = "0.1" libc = "0.2" @@ -53,10 +55,11 @@ rand = "0.4" regex = "0.2" serde = "1.0" serde_json = "1.0" -sha1 = "0.4" +sha1 = "0.6" smallvec = "0.6" time = "0.1" -url = "1.6" +encoding = "0.2" +url = { version="1.7", features=["query_encoding"] } cookie = { version="0.10", features=["percent-encode", "secure"] } # io @@ -78,7 +81,7 @@ openssl = { version="0.10", optional = true } tokio-openssl = { version="0.2", optional = true } [dependencies.actix] -version = "0.5" +version = "^0.5.1" [dev-dependencies] env_logger = "0.5" @@ -98,16 +101,20 @@ codegen-units = 1 members = [ "./", "examples/basics", + "examples/juniper", "examples/diesel", "examples/r2d2", "examples/json", "examples/hello-world", + "examples/http-proxy", "examples/multipart", "examples/state", + "examples/redis-session", "examples/template_tera", "examples/tls", "examples/websocket", "examples/websocket-chat", "examples/web-cors/backend", + "examples/unix-socket", "tools/wsload/", ] diff --git a/README.md b/README.md index 47d7db733..7c8ac4e0a 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # Actix web [![Build Status](https://travis-ci.org/actix/actix-web.svg?branch=master)](https://travis-ci.org/actix/actix-web) [![Build status](https://ci.appveyor.com/api/projects/status/kkdb4yce7qhm5w85/branch/master?svg=true)](https://ci.appveyor.com/project/fafhrd91/actix-web-hdy9d/branch/master) [![codecov](https://codecov.io/gh/actix/actix-web/branch/master/graph/badge.svg)](https://codecov.io/gh/actix/actix-web) [![crates.io](http://meritbadge.herokuapp.com/actix-web)](https://crates.io/crates/actix-web) [![Join the chat at https://gitter.im/actix/actix](https://badges.gitter.im/actix/actix.svg)](https://gitter.im/actix/actix?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) -Actix web is a small, fast, pragmatic, open source rust web framework. +Actix web is a small, pragmatic, extremely fast, web framework for Rust. * Supported *HTTP/1.x* and [*HTTP/2.0*](https://actix.github.io/actix-web/guide/qs_13.html) protocols * Streaming and pipelining @@ -10,11 +10,13 @@ Actix web is a small, fast, pragmatic, open source rust web framework. * Configurable [request routing](https://actix.github.io/actix-web/guide/qs_5.html) * Graceful server shutdown * Multipart streams +* SSL support with openssl or native-tls * Middlewares ([Logger](https://actix.github.io/actix-web/guide/qs_10.html#logging), [Session](https://actix.github.io/actix-web/guide/qs_10.html#user-sessions), + [Redis sessions](https://github.com/actix/actix-redis), [DefaultHeaders](https://actix.github.io/actix-web/guide/qs_10.html#default-headers), [CORS](https://actix.github.io/actix-web/actix_web/middleware/cors/index.html)) -* Built on top of [Actix](https://github.com/actix/actix). +* Built on top of [Actix actor framework](https://github.com/actix/actix). ## Documentation @@ -48,7 +50,7 @@ fn main() { * [Basics](https://github.com/actix/actix-web/tree/master/examples/basics/) * [Stateful](https://github.com/actix/actix-web/tree/master/examples/state/) -* [Mulitpart streams](https://github.com/actix/actix-web/tree/master/examples/multipart/) +* [Multipart streams](https://github.com/actix/actix-web/tree/master/examples/multipart/) * [Simple websocket session](https://github.com/actix/actix-web/tree/master/examples/websocket/) * [Tera templates](https://github.com/actix/actix-web/tree/master/examples/template_tera/) * [Diesel integration](https://github.com/actix/actix-web/tree/master/examples/diesel/) @@ -57,11 +59,14 @@ fn main() { * [SockJS Server](https://github.com/actix/actix-sockjs) * [Json](https://github.com/actix/actix-web/tree/master/examples/json/) +You may consider checking out +[this directory](https://github.com/actix/actix-web/tree/master/examples) for more examples. + ## Benchmarks * [TechEmpower Framework Benchmark](https://www.techempower.com/benchmarks/#section=data-r15&hw=ph&test=plaintext) -* Some basic benchmarks could be found in this [respository](https://github.com/fafhrd91/benchmarks). +* Some basic benchmarks could be found in this [repository](https://github.com/fafhrd91/benchmarks). ## License diff --git a/examples/basics/src/main.rs b/examples/basics/src/main.rs index f52b09544..55e4485e0 100644 --- a/examples/basics/src/main.rs +++ b/examples/basics/src/main.rs @@ -22,7 +22,7 @@ fn index(mut req: HttpRequest) -> Result { println!("{:?}", req); // example of ... - if let Ok(ch) = req.payload_mut().readany().poll() { + if let Ok(ch) = req.poll() { if let futures::Async::Ready(Some(d)) = ch { println!("{}", String::from_utf8_lossy(d.as_ref())); } @@ -139,7 +139,7 @@ fn main() { // default .default_resource(|r| { r.method(Method::GET).f(p404); - r.route().p(pred::Not(pred::Get())).f(|req| httpcodes::HTTPMethodNotAllowed); + r.route().filter(pred::Not(pred::Get())).f(|req| httpcodes::HTTPMethodNotAllowed); })) .bind("127.0.0.1:8080").expect("Can not bind to 127.0.0.1:8080") diff --git a/examples/http-proxy/Cargo.toml b/examples/http-proxy/Cargo.toml new file mode 100644 index 000000000..7b9597bff --- /dev/null +++ b/examples/http-proxy/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "http-proxy" +version = "0.1.0" +authors = ["Nikolay Kim "] +workspace = "../.." + +[dependencies] +env_logger = "0.5" +futures = "0.1" +actix = "0.5" +actix-web = { path = "../../", features=["alpn"] } diff --git a/examples/http-proxy/src/main.rs b/examples/http-proxy/src/main.rs new file mode 100644 index 000000000..551101c97 --- /dev/null +++ b/examples/http-proxy/src/main.rs @@ -0,0 +1,59 @@ +extern crate actix; +extern crate actix_web; +extern crate futures; +extern crate env_logger; + +use actix_web::*; +use futures::{Future, Stream}; + + +/// Stream client request response and then send body to a server response +fn index(_req: HttpRequest) -> Box> { + client::ClientRequest::get("https://www.rust-lang.org/en-US/") + .finish().unwrap() + .send() + .map_err(error::Error::from) // <- convert SendRequestError to an Error + .and_then( + |resp| resp.body() // <- this is MessageBody type, resolves to complete body + .from_err() // <- convet PayloadError to a Error + .and_then(|body| { // <- we got complete body, now send as server response + httpcodes::HttpOk.build() + .body(body) + .map_err(error::Error::from) + })) + .responder() +} + +/// streaming client request to a streaming server response +fn streaming(_req: HttpRequest) -> Box> { + // send client request + client::ClientRequest::get("https://www.rust-lang.org/en-US/") + .finish().unwrap() + .send() // <- connect to host and send request + .map_err(error::Error::from) // <- convert SendRequestError to an Error + .and_then(|resp| { // <- we received client response + httpcodes::HttpOk.build() + // read one chunk from client response and send this chunk to a server response + // .from_err() converts PayloadError to a Error + .body(Body::Streaming(Box::new(resp.from_err()))) + .map_err(|e| e.into()) // HttpOk::build() mayb return HttpError, we need to convert it to a Error + }) + .responder() +} + +fn main() { + ::std::env::set_var("RUST_LOG", "actix_web=info"); + env_logger::init(); + let sys = actix::System::new("http-proxy"); + + let _addr = HttpServer::new( + || Application::new() + .middleware(middleware::Logger::default()) + .resource("/streaming", |r| r.f(streaming)) + .resource("/", |r| r.f(index))) + .bind("127.0.0.1:8080").unwrap() + .start(); + + println!("Started http server: 127.0.0.1:8080"); + let _ = sys.run(); +} diff --git a/examples/json/src/main.rs b/examples/json/src/main.rs index 719d74853..3247e5d6c 100644 --- a/examples/json/src/main.rs +++ b/examples/json/src/main.rs @@ -34,9 +34,9 @@ fn index(req: HttpRequest) -> Box> { const MAX_SIZE: usize = 262_144; // max payload size is 256k /// This handler manually load request payload and parse serde json -fn index_manual(mut req: HttpRequest) -> Box> { - // readany() returns asynchronous stream of Bytes objects - req.payload_mut().readany() +fn index_manual(req: HttpRequest) -> Box> { + // HttpRequest is stream of Bytes objects + req // `Future::from_err` acts like `?` in that it coerces the error type from // the future into the final error type .from_err() @@ -63,8 +63,8 @@ fn index_manual(mut req: HttpRequest) -> Box Box> { - req.payload_mut().readany().concat2() +fn index_mjsonrust(req: HttpRequest) -> Box> { + req.concat2() .from_err() .and_then(|body| { // body is loaded, now we can deserialize json-rust diff --git a/examples/juniper/Cargo.toml b/examples/juniper/Cargo.toml new file mode 100644 index 000000000..9e52b0a83 --- /dev/null +++ b/examples/juniper/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "juniper-example" +version = "0.1.0" +authors = ["pyros2097 "] +workspace = "../.." + +[dependencies] +env_logger = "0.5" +actix = "0.5" +actix-web = { path = "../../" } + +futures = "0.1" +serde = "1.0" +serde_json = "1.0" +serde_derive = "1.0" + +juniper = "0.9.2" diff --git a/examples/juniper/README.md b/examples/juniper/README.md new file mode 100644 index 000000000..2ac0eac4e --- /dev/null +++ b/examples/juniper/README.md @@ -0,0 +1,15 @@ +# juniper + +Juniper integration for Actix web + +### server + +```bash +cd actix-web/examples/juniper +cargo run (or ``cargo watch -x run``) +# Started http server: 127.0.0.1:8080 +``` + +### web client + +[http://127.0.0.1:8080/graphiql](http://127.0.0.1:8080/graphiql) diff --git a/examples/juniper/src/main.rs b/examples/juniper/src/main.rs new file mode 100644 index 000000000..c0be2754e --- /dev/null +++ b/examples/juniper/src/main.rs @@ -0,0 +1,110 @@ +//! Actix web juniper example +//! +//! A simple example integrating juniper in actix-web +extern crate serde; +extern crate serde_json; +#[macro_use] +extern crate serde_derive; +#[macro_use] +extern crate juniper; +extern crate futures; +extern crate actix; +extern crate actix_web; +extern crate env_logger; + +use actix::*; +use actix_web::*; +use juniper::http::graphiql::graphiql_source; +use juniper::http::GraphQLRequest; + +use futures::future::Future; + +mod schema; + +use schema::Schema; +use schema::create_schema; + +struct State { + executor: Addr, +} + +#[derive(Serialize, Deserialize)] +pub struct GraphQLData(GraphQLRequest); + +impl Message for GraphQLData { + type Result = Result; +} + +pub struct GraphQLExecutor { + schema: std::sync::Arc +} + +impl GraphQLExecutor { + fn new(schema: std::sync::Arc) -> GraphQLExecutor { + GraphQLExecutor { + schema: schema, + } + } +} + +impl Actor for GraphQLExecutor { + type Context = SyncContext; +} + +impl Handler for GraphQLExecutor { + type Result = Result; + + fn handle(&mut self, msg: GraphQLData, _: &mut Self::Context) -> Self::Result { + let res = msg.0.execute(&self.schema, &()); + let res_text = serde_json::to_string(&res)?; + Ok(res_text) + } +} + +fn graphiql(_req: HttpRequest) -> Result { + let html = graphiql_source("http://127.0.0.1:8080/graphql"); + Ok(HttpResponse::build(StatusCode::OK) + .content_type("text/html; charset=utf-8") + .body(html).unwrap()) +} + +fn graphql(req: HttpRequest) -> Box> { + let executor = req.state().executor.clone(); + req.json() + .from_err() + .and_then(move |val: GraphQLData| { + executor.send(val) + .from_err() + .and_then(|res| { + match res { + Ok(user) => Ok(httpcodes::HTTPOk.build().body(user)?), + Err(_) => Ok(httpcodes::HTTPInternalServerError.into()) + } + }) + }) + .responder() +} + +fn main() { + ::std::env::set_var("RUST_LOG", "actix_web=info"); + let _ = env_logger::init(); + let sys = actix::System::new("juniper-example"); + + let schema = std::sync::Arc::new(create_schema()); + let addr = SyncArbiter::start(3, move || { + GraphQLExecutor::new(schema.clone()) + }); + + // Start http server + let _addr = HttpServer::new(move || { + Application::with_state(State{executor: addr.clone()}) + // enable logger + .middleware(middleware::Logger::default()) + .resource("/graphql", |r| r.method(Method::POST).a(graphql)) + .resource("/graphiql", |r| r.method(Method::GET).f(graphiql))}) + .bind("127.0.0.1:8080").unwrap() + .start(); + + println!("Started http server: 127.0.0.1:8080"); + let _ = sys.run(); +} diff --git a/examples/juniper/src/schema.rs b/examples/juniper/src/schema.rs new file mode 100644 index 000000000..2b4cf3042 --- /dev/null +++ b/examples/juniper/src/schema.rs @@ -0,0 +1,58 @@ +use juniper::FieldResult; +use juniper::RootNode; + +#[derive(GraphQLEnum)] +enum Episode { + NewHope, + Empire, + Jedi, +} + +#[derive(GraphQLObject)] +#[graphql(description = "A humanoid creature in the Star Wars universe")] +struct Human { + id: String, + name: String, + appears_in: Vec, + home_planet: String, +} + +#[derive(GraphQLInputObject)] +#[graphql(description = "A humanoid creature in the Star Wars universe")] +struct NewHuman { + name: String, + appears_in: Vec, + home_planet: String, +} + +pub struct QueryRoot; + +graphql_object!(QueryRoot: () |&self| { + field human(&executor, id: String) -> FieldResult { + Ok(Human{ + id: "1234".to_owned(), + name: "Luke".to_owned(), + appears_in: vec![Episode::NewHope], + home_planet: "Mars".to_owned(), + }) + } +}); + +pub struct MutationRoot; + +graphql_object!(MutationRoot: () |&self| { + field createHuman(&executor, new_human: NewHuman) -> FieldResult { + Ok(Human{ + id: "1234".to_owned(), + name: new_human.name, + appears_in: new_human.appears_in, + home_planet: new_human.home_planet, + }) + } +}); + +pub type Schema = RootNode<'static, QueryRoot, MutationRoot>; + +pub fn create_schema() -> Schema { + Schema::new(QueryRoot {}, MutationRoot {}) +} diff --git a/examples/multipart/src/main.rs b/examples/multipart/src/main.rs index 7da6145a9..343dde167 100644 --- a/examples/multipart/src/main.rs +++ b/examples/multipart/src/main.rs @@ -11,7 +11,7 @@ use futures::{Future, Stream}; use futures::future::{result, Either}; -fn index(mut req: HttpRequest) -> Box> +fn index(req: HttpRequest) -> Box> { println!("{:?}", req); diff --git a/examples/redis-session/Cargo.toml b/examples/redis-session/Cargo.toml new file mode 100644 index 000000000..cfa102d11 --- /dev/null +++ b/examples/redis-session/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "redis-session" +version = "0.1.0" +authors = ["Nikolay Kim "] +workspace = "../.." + +[dependencies] +env_logger = "0.5" +actix = "0.5" +actix-web = "0.4" +actix-redis = { version = "0.2", features = ["web"] } diff --git a/examples/redis-session/src/main.rs b/examples/redis-session/src/main.rs new file mode 100644 index 000000000..36df16559 --- /dev/null +++ b/examples/redis-session/src/main.rs @@ -0,0 +1,48 @@ +#![allow(unused_variables)] + +extern crate actix; +extern crate actix_web; +extern crate actix_redis; +extern crate env_logger; + +use actix_web::*; +use actix_web::middleware::RequestSession; +use actix_redis::RedisSessionBackend; + + +/// simple handler +fn index(mut req: HttpRequest) -> Result { + println!("{:?}", req); + + // session + if let Some(count) = req.session().get::("counter")? { + println!("SESSION value: {}", count); + req.session().set("counter", count+1)?; + } else { + req.session().set("counter", 1)?; + } + + Ok("Welcome!".into()) +} + +fn main() { + ::std::env::set_var("RUST_LOG", "actix_web=info,actix_redis=info"); + env_logger::init(); + let sys = actix::System::new("basic-example"); + + HttpServer::new( + || Application::new() + // enable logger + .middleware(middleware::Logger::default()) + // cookie session middleware + .middleware(middleware::SessionStorage::new( + RedisSessionBackend::new("127.0.0.1:6379", &[0; 32]) + )) + // register simple route, handle all methods + .resource("/", |r| r.f(index))) + .bind("0.0.0.0:8080").unwrap() + .threads(1) + .start(); + + let _ = sys.run(); +} diff --git a/examples/state/src/main.rs b/examples/state/src/main.rs index 21eb50483..0f7e0ec3b 100644 --- a/examples/state/src/main.rs +++ b/examples/state/src/main.rs @@ -36,8 +36,7 @@ impl Actor for MyWebSocket { type Context = ws::WebsocketContext; } -impl Handler for MyWebSocket { - type Result = (); +impl StreamHandler for MyWebSocket { fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) { self.counter += 1; @@ -46,7 +45,7 @@ impl Handler for MyWebSocket { ws::Message::Ping(msg) => ctx.pong(&msg), ws::Message::Text(text) => ctx.text(text), ws::Message::Binary(bin) => ctx.binary(bin), - ws::Message::Closed | ws::Message::Error => { + ws::Message::Close(_) => { ctx.stop(); } _ => (), diff --git a/examples/unix-socket/Cargo.toml b/examples/unix-socket/Cargo.toml new file mode 100644 index 000000000..a7c31f212 --- /dev/null +++ b/examples/unix-socket/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "unix-socket" +version = "0.1.0" +authors = ["Messense Lv "] + +[dependencies] +env_logger = "0.5" +actix = "0.5" +actix-web = { path = "../../" } +tokio-uds = "0.1" diff --git a/examples/unix-socket/README.md b/examples/unix-socket/README.md new file mode 100644 index 000000000..03b0066a2 --- /dev/null +++ b/examples/unix-socket/README.md @@ -0,0 +1,14 @@ +## Unix domain socket example + +```bash +$ curl --unix-socket /tmp/actix-uds.socket http://localhost/ +Hello world! +``` + +Although this will only one thread for handling incoming connections +according to the +[documentation](https://actix.github.io/actix-web/actix_web/struct.HttpServer.html#method.start_incoming). + +And it does not delete the socket file (`/tmp/actix-uds.socket`) when stopping +the server so it will fail to start next time you run it unless you delete +the socket file manually. diff --git a/examples/unix-socket/src/main.rs b/examples/unix-socket/src/main.rs new file mode 100644 index 000000000..a56d428a7 --- /dev/null +++ b/examples/unix-socket/src/main.rs @@ -0,0 +1,31 @@ +extern crate actix; +extern crate actix_web; +extern crate env_logger; +extern crate tokio_uds; + +use actix::*; +use actix_web::*; +use tokio_uds::UnixListener; + + +fn index(_req: HttpRequest) -> &'static str { + "Hello world!" +} + +fn main() { + ::std::env::set_var("RUST_LOG", "actix_web=info"); + let _ = env_logger::init(); + let sys = actix::System::new("unix-socket"); + + let listener = UnixListener::bind("/tmp/actix-uds.socket", Arbiter::handle()).expect("bind failed"); + let _addr = HttpServer::new( + || Application::new() + // enable logger + .middleware(middleware::Logger::default()) + .resource("/index.html", |r| r.f(|_| "Hello world!")) + .resource("/", |r| r.f(index))) + .start_incoming(listener.incoming(), false); + + println!("Started http server: /tmp/actix-uds.socket"); + let _ = sys.run(); +} diff --git a/examples/websocket-chat/src/main.rs b/examples/websocket-chat/src/main.rs index 821fcfa57..dccd768aa 100644 --- a/examples/websocket-chat/src/main.rs +++ b/examples/websocket-chat/src/main.rs @@ -92,8 +92,7 @@ impl Handler for WsChatSession { } /// WebSocket message handler -impl Handler for WsChatSession { - type Result = (); +impl StreamHandler for WsChatSession { fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) { println!("WEBSOCKET MESSAGE: {:?}", msg); @@ -161,10 +160,9 @@ impl Handler for WsChatSession { }, ws::Message::Binary(bin) => println!("Unexpected binary"), - ws::Message::Closed | ws::Message::Error => { + ws::Message::Close(_) => { ctx.stop(); } - _ => (), } } } diff --git a/examples/websocket/src/client.rs b/examples/websocket/src/client.rs index e35c71bb1..34ff24372 100644 --- a/examples/websocket/src/client.rs +++ b/examples/websocket/src/client.rs @@ -12,7 +12,7 @@ use std::time::Duration; use actix::*; use futures::Future; -use actix_web::ws::{Message, WsClientError, WsClient, WsClientWriter}; +use actix_web::ws::{Message, ProtocolError, Client, ClientWriter}; fn main() { @@ -21,8 +21,8 @@ fn main() { let sys = actix::System::new("ws-example"); Arbiter::handle().spawn( - WsClient::new("http://127.0.0.1:8080/ws/") - .connect().unwrap() + Client::new("http://127.0.0.1:8080/ws/") + .connect() .map_err(|e| { println!("Error: {}", e); () @@ -53,7 +53,7 @@ fn main() { } -struct ChatClient(WsClientWriter); +struct ChatClient(ClientWriter); #[derive(Message)] struct ClientCommand(String); @@ -88,12 +88,12 @@ impl Handler for ChatClient { type Result = (); fn handle(&mut self, msg: ClientCommand, ctx: &mut Context) { - self.0.text(msg.0.as_str()) + self.0.text(msg.0) } } /// Handle server websocket messages -impl StreamHandler for ChatClient { +impl StreamHandler for ChatClient { fn handle(&mut self, msg: Message, ctx: &mut Context) { match msg { diff --git a/examples/websocket/src/main.rs b/examples/websocket/src/main.rs index 9149ead71..f97b948de 100644 --- a/examples/websocket/src/main.rs +++ b/examples/websocket/src/main.rs @@ -25,8 +25,7 @@ impl Actor for MyWebSocket { } /// Handler for `ws::Message` -impl Handler for MyWebSocket { - type Result = (); +impl StreamHandler for MyWebSocket { fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) { // process websocket messages @@ -35,7 +34,7 @@ impl Handler for MyWebSocket { ws::Message::Ping(msg) => ctx.pong(&msg), ws::Message::Text(text) => ctx.text(text), ws::Message::Binary(bin) => ctx.binary(bin), - ws::Message::Closed | ws::Message::Error => { + ws::Message::Close(_) => { ctx.stop(); } _ => (), diff --git a/guide/src/qs_10.md b/guide/src/qs_10.md index ed36140c7..3e007bcab 100644 --- a/guide/src/qs_10.md +++ b/guide/src/qs_10.md @@ -53,7 +53,7 @@ impl Middleware for Headers { fn main() { Application::new() .middleware(Headers) // <- Register middleware, this method could be called multiple times - .resource("/", |r| r.h(httpcodes::HTTPOk)); + .resource("/", |r| r.h(httpcodes::HttpOk)); } ``` @@ -144,8 +144,8 @@ fn main() { .header("X-Version", "0.2") .finish()) .resource("/test", |r| { - r.method(Method::GET).f(|req| httpcodes::HTTPOk); - r.method(Method::HEAD).f(|req| httpcodes::HTTPMethodNotAllowed); + r.method(Method::GET).f(|req| httpcodes::HttpOk); + r.method(Method::HEAD).f(|req| httpcodes::HttpMethodNotAllowed); }) .finish(); } diff --git a/guide/src/qs_14.md b/guide/src/qs_14.md index c318bcaad..72827e4eb 100644 --- a/guide/src/qs_14.md +++ b/guide/src/qs_14.md @@ -12,7 +12,7 @@ We have to define sync actor and connection that this actor will use. Same appro could be used for other databases. ```rust,ignore -use actix::prelude::*;* +use actix::prelude::*; struct DbExecutor(SqliteConnection); @@ -36,7 +36,7 @@ We can send `CreateUser` message to `DbExecutor` actor, and as result we get ```rust,ignore impl Handler for DbExecutor { - type Result = Result + type Result = Result; fn handle(&mut self, msg: CreateUser, _: &mut Self::Context) -> Self::Result { @@ -110,8 +110,8 @@ fn index(req: HttpRequest) -> Box> .from_err() .and_then(|res| { match res { - Ok(user) => Ok(httpcodes::HTTPOk.build().json(user)?), - Err(_) => Ok(httpcodes::HTTPInternalServerError.into()) + Ok(user) => Ok(httpcodes::HttpOk.build().json(user)?), + Err(_) => Ok(httpcodes::HttpInternalServerError.into()) } }) .responder() diff --git a/guide/src/qs_2.md b/guide/src/qs_2.md index ac53d8707..01cb98499 100644 --- a/guide/src/qs_2.md +++ b/guide/src/qs_2.md @@ -20,8 +20,8 @@ contains the following: ```toml [dependencies] -actix = "0.4" -actix-web = "0.3" +actix = "0.5" +actix-web = "0.4" ``` In order to implement a web server, first we need to create a request handler. diff --git a/guide/src/qs_3.md b/guide/src/qs_3.md index 6d9c1a426..341b62cc0 100644 --- a/guide/src/qs_3.md +++ b/guide/src/qs_3.md @@ -49,12 +49,12 @@ fn main() { HttpServer::new(|| vec![ Application::new() .prefix("/app1") - .resource("/", |r| r.f(|r| httpcodes::HTTPOk)), + .resource("/", |r| r.f(|r| httpcodes::HttpOk)), Application::new() .prefix("/app2") - .resource("/", |r| r.f(|r| httpcodes::HTTPOk)), + .resource("/", |r| r.f(|r| httpcodes::HttpOk)), Application::new() - .resource("/", |r| r.f(|r| httpcodes::HTTPOk)), + .resource("/", |r| r.f(|r| httpcodes::HttpOk)), ]); } ``` diff --git a/guide/src/qs_3_5.md b/guide/src/qs_3_5.md index 99c2bcd9a..3f1fff00e 100644 --- a/guide/src/qs_3_5.md +++ b/guide/src/qs_3_5.md @@ -20,7 +20,7 @@ fn main() { HttpServer::new( || Application::new() - .resource("/", |r| r.h(httpcodes::HTTPOk))) + .resource("/", |r| r.h(httpcodes::HttpOk))) .bind("127.0.0.1:59080").unwrap() .start(); @@ -57,7 +57,7 @@ fn main() { let sys = actix::System::new("http-server"); let addr = HttpServer::new( || Application::new() - .resource("/", |r| r.h(httpcodes::HTTPOk))) + .resource("/", |r| r.h(httpcodes::HttpOk))) .bind("127.0.0.1:0").expect("Can not bind to 127.0.0.1:0") .shutdown_timeout(60) // <- Set shutdown timeout to 60 seconds .start(); @@ -85,7 +85,7 @@ use actix_web::*; fn main() { HttpServer::new( || Application::new() - .resource("/", |r| r.h(httpcodes::HTTPOk))) + .resource("/", |r| r.h(httpcodes::HttpOk))) .threads(4); // <- Start 4 workers } ``` @@ -146,7 +146,7 @@ use actix_web::*; fn main() { HttpServer::new(|| Application::new() - .resource("/", |r| r.h(httpcodes::HTTPOk))) + .resource("/", |r| r.h(httpcodes::HttpOk))) .keep_alive(None); // <- Use `SO_KEEPALIVE` socket option. } ``` @@ -155,7 +155,7 @@ If first option is selected then *keep alive* state calculated based on response's *connection-type*. By default `HttpResponse::connection_type` is not defined in that case *keep alive* defined by request's http version. Keep alive is off for *HTTP/1.0* -and is on for *HTTP/1.1* and "HTTP/2.0". +and is on for *HTTP/1.1* and *HTTP/2.0*. *Connection type* could be change with `HttpResponseBuilder::connection_type()` method. @@ -165,7 +165,7 @@ and is on for *HTTP/1.1* and "HTTP/2.0". use actix_web::*; fn index(req: HttpRequest) -> HttpResponse { - HTTPOk.build() + HttpOk.build() .connection_type(headers::ConnectionType::Close) // <- Close connection .force_close() // <- Alternative method .finish().unwrap() diff --git a/guide/src/qs_4.md b/guide/src/qs_4.md index c7cbc6c94..486e9df58 100644 --- a/guide/src/qs_4.md +++ b/guide/src/qs_4.md @@ -65,7 +65,7 @@ impl Handler for MyHandler { /// Handle request fn handle(&mut self, req: HttpRequest) -> Self::Result { self.0 += 1; - httpcodes::HTTPOk.into() + httpcodes::HttpOk.into() } } # fn main() {} @@ -90,7 +90,7 @@ impl Handler for MyHandler { /// Handle request fn handle(&mut self, req: HttpRequest) -> Self::Result { self.0.fetch_add(1, Ordering::Relaxed); - httpcodes::HTTPOk.into() + httpcodes::HttpOk.into() } } diff --git a/guide/src/qs_4_5.md b/guide/src/qs_4_5.md index 5a11af733..01808c605 100644 --- a/guide/src/qs_4_5.md +++ b/guide/src/qs_4_5.md @@ -14,7 +14,7 @@ impl> Responder for Result And any error that implements `ResponseError` can be converted into `Error` object. For example if *handler* function returns `io::Error`, it would be converted -into `HTTPInternalServerError` response. Implementation for `io::Error` is provided +into `HttpInternalServerError` response. Implementation for `io::Error` is provided by default. ```rust diff --git a/guide/src/qs_5.md b/guide/src/qs_5.md index c1e8b615c..21b2f8c64 100644 --- a/guide/src/qs_5.md +++ b/guide/src/qs_5.md @@ -32,7 +32,7 @@ fn main() { Application::new() .resource("/prefix", |r| r.f(index)) .resource("/user/{name}", - |r| r.method(Method::GET).f(|req| HTTPOk)) + |r| r.method(Method::GET).f(|req| HttpOk)) .finish(); } ``` @@ -52,7 +52,7 @@ returns *NOT FOUND* http resources. Resource contains set of routes. Each route in turn has set of predicates and handler. New route could be created with `Resource::route()` method which returns reference to new *Route* instance. By default *route* does not contain any predicates, so matches -all requests and default handler is `HTTPNotFound`. +all requests and default handler is `HttpNotFound`. Application routes incoming requests based on route criteria which is defined during resource registration and route registration. Resource matches all routes it contains in @@ -68,9 +68,9 @@ fn main() { Application::new() .resource("/path", |resource| resource.route() - .p(pred::Get()) - .p(pred::Header("content-type", "text/plain")) - .f(|req| HTTPOk) + .filter(pred::Get()) + .filter(pred::Header("content-type", "text/plain")) + .f(|req| HttpOk) ) .finish(); } @@ -85,7 +85,7 @@ If resource can not match any route "NOT FOUND" response get returned. [*Route*](../actix_web/struct.Route.html) object. Route can be configured with builder-like pattern. Following configuration methods are available: -* [*Route::p()*](../actix_web/struct.Route.html#method.p) method registers new predicate, +* [*Route::filter()*](../actix_web/struct.Route.html#method.filter) method registers new predicate, any number of predicates could be registered for each route. * [*Route::f()*](../actix_web/struct.Route.html#method.f) method registers handler function @@ -336,14 +336,14 @@ resource with the name "foo" and the pattern "{a}/{b}/{c}", you might do this. # fn index(req: HttpRequest) -> HttpResponse { let url = req.url_for("foo", &["1", "2", "3"]); // <- generate url for "foo" resource - HTTPOk.into() + HttpOk.into() } fn main() { let app = Application::new() .resource("/test/{a}/{b}/{c}", |r| { r.name("foo"); // <- set resource name, then it could be used in `url_for` - r.method(Method::GET).f(|_| httpcodes::HTTPOk); + r.method(Method::GET).f(|_| httpcodes::HttpOk); }) .finish(); } @@ -367,7 +367,7 @@ use actix_web::*; fn index(mut req: HttpRequest) -> Result { let url = req.url_for("youtube", &["oHg5SJYRHA0"])?; assert_eq!(url.as_str(), "https://youtube.com/watch/oHg5SJYRHA0"); - Ok(httpcodes::HTTPOk.into()) + Ok(httpcodes::HttpOk.into()) } fn main() { @@ -404,7 +404,7 @@ This handler designed to be use as a handler for application's *default resource # use actix_web::*; # # fn index(req: HttpRequest) -> httpcodes::StaticResponse { -# httpcodes::HTTPOk +# httpcodes::HttpOk # } fn main() { let app = Application::new() @@ -429,7 +429,7 @@ It is possible to register path normalization only for *GET* requests only # use actix_web::*; # # fn index(req: HttpRequest) -> httpcodes::StaticResponse { -# httpcodes::HTTPOk +# httpcodes::HttpOk # } fn main() { let app = Application::new() @@ -502,8 +502,8 @@ fn main() { Application::new() .resource("/index.html", |r| r.route() - .p(ContentTypeHeader) - .h(HTTPOk)); + .filter(ContentTypeHeader) + .h(HttpOk)); } ``` @@ -530,8 +530,8 @@ fn main() { Application::new() .resource("/index.html", |r| r.route() - .p(pred::Not(pred::Get())) - .f(|req| HTTPMethodNotAllowed)) + .filter(pred::Not(pred::Get())) + .f(|req| HttpMethodNotAllowed)) .finish(); } ``` @@ -567,8 +567,8 @@ use actix_web::httpcodes::*; fn main() { Application::new() .default_resource(|r| { - r.method(Method::GET).f(|req| HTTPNotFound); - r.route().p(pred::Not(pred::Get())).f(|req| HTTPMethodNotAllowed); + r.method(Method::GET).f(|req| HttpNotFound); + r.route().filter(pred::Not(pred::Get())).f(|req| HttpMethodNotAllowed); }) # .finish(); } diff --git a/guide/src/qs_7.md b/guide/src/qs_7.md index 3a96529a0..e7c6bc88b 100644 --- a/guide/src/qs_7.md +++ b/guide/src/qs_7.md @@ -84,7 +84,7 @@ fn index(mut req: HttpRequest) -> Box> { req.json().from_err() .and_then(|val: MyObj| { println!("model: {:?}", val); - Ok(httpcodes::HTTPOk.build().json(val)?) // <- send response + Ok(httpcodes::HttpOk.build().json(val)?) // <- send response }) .responder() } @@ -106,10 +106,10 @@ use futures::{Future, Stream}; #[derive(Serialize, Deserialize)] struct MyObj {name: String, number: i32} -fn index(mut req: HttpRequest) -> Box> { +fn index(req: HttpRequest) -> Box> { // `concat2` will asynchronously read each chunk of the request body and // return a single, concatenated, chunk - req.payload_mut().readany().concat2() + req.concat2() // `Future::from_err` acts like `?` in that it coerces the error type from // the future into the final error type .from_err() @@ -117,7 +117,7 @@ fn index(mut req: HttpRequest) -> Box> { // synchronous workflow .and_then(|body| { // <- body is loaded, now we can deserialize json let obj = serde_json::from_slice::(&body)?; - Ok(httpcodes::HTTPOk.build().json(obj)?) // <- send response + Ok(httpcodes::HttpOk.build().json(obj)?) // <- send response }) .responder() } @@ -169,13 +169,18 @@ get enabled automatically. Enabling chunked encoding for *HTTP/2.0* responses is forbidden. ```rust +# extern crate bytes; # extern crate actix_web; +# extern crate futures; +# use futures::Stream; use actix_web::*; +use bytes::Bytes; +use futures::stream::once; fn index(req: HttpRequest) -> HttpResponse { HttpResponse::Ok() .chunked() - .body(Body::Streaming(payload::Payload::empty().stream())).unwrap() + .body(Body::Streaming(Box::new(once(Ok(Bytes::from_static(b"data")))))).unwrap() } # fn main() {} ``` @@ -246,7 +251,7 @@ fn index(mut req: HttpRequest) -> Box> { .from_err() .and_then(|params| { // <- url encoded parameters println!("==== BODY ==== {:?}", params); - ok(httpcodes::HTTPOk.into()) + ok(httpcodes::HttpOk.into()) }) .responder() } @@ -256,21 +261,8 @@ fn index(mut req: HttpRequest) -> Box> { ## Streaming request -Actix uses [*Payload*](../actix_web/payload/struct.Payload.html) object as request payload stream. -*HttpRequest* provides several methods, which can be used for payload access. -At the same time *Payload* implements *Stream* trait, so it could be used with various -stream combinators. Also *Payload* provides several convenience methods that return -future object that resolve to Bytes object. - -* *readany()* method returns *Stream* of *Bytes* objects. - -* *readexactly()* method returns *Future* that resolves when specified number of bytes - get received. - -* *readline()* method returns *Future* that resolves when `\n` get received. - -* *readuntil()* method returns *Future* that resolves when specified bytes string - matches in input bytes stream +*HttpRequest* is a stream of `Bytes` objects. It could be used to read request +body payload. In this example handle reads request payload chunk by chunk and prints every chunk. @@ -283,9 +275,7 @@ use futures::{Future, Stream}; fn index(mut req: HttpRequest) -> Box> { - req.payload() - .readany() - .from_err() + req.from_err() .fold((), |_, chunk| { println!("Chunk: {:?}", chunk); result::<_, error::PayloadError>(Ok(())) diff --git a/guide/src/qs_8.md b/guide/src/qs_8.md index 2e2b54201..74e7421d2 100644 --- a/guide/src/qs_8.md +++ b/guide/src/qs_8.md @@ -20,10 +20,10 @@ use actix_web::test::TestRequest; fn index(req: HttpRequest) -> HttpResponse { if let Some(hdr) = req.headers().get(header::CONTENT_TYPE) { if let Ok(s) = hdr.to_str() { - return httpcodes::HTTPOk.into() + return httpcodes::HttpOk.into() } } - httpcodes::HTTPBadRequest.into() + httpcodes::HttpBadRequest.into() } fn main() { @@ -59,16 +59,16 @@ use actix_web::*; use actix_web::test::TestServer; fn index(req: HttpRequest) -> HttpResponse { - httpcodes::HTTPOk.into() + httpcodes::HttpOk.into() } fn main() { let mut srv = TestServer::new(|app| app.handler(index)); // <- Start new test server - + let request = srv.get().finish().unwrap(); // <- create client request let response = srv.execute(request.send()).unwrap(); // <- send request to the server assert!(response.status().is_success()); // <- check response - + let bytes = srv.execute(response.body()).unwrap(); // <- read response body } ``` @@ -84,7 +84,7 @@ use actix_web::*; use actix_web::test::TestServer; fn index(req: HttpRequest) -> HttpResponse { - httpcodes::HTTPOk.into() + httpcodes::HttpOk.into() } /// This function get called by http server. @@ -130,8 +130,7 @@ impl Actor for Ws { type Context = ws::WebsocketContext; } -impl Handler for Ws { - type Result = (); +impl StreamHandler for Ws { fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) { match msg { diff --git a/guide/src/qs_9.md b/guide/src/qs_9.md index dbca38384..fa8b979ae 100644 --- a/guide/src/qs_9.md +++ b/guide/src/qs_9.md @@ -21,9 +21,8 @@ impl Actor for Ws { type Context = ws::WebsocketContext; } -/// Define Handler for ws::Message message -impl Handler for Ws { - type Result=(); +/// Handler for ws::Message message +impl StreamHandler for Ws { fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) { match msg { @@ -43,7 +42,7 @@ fn main() { ``` Simple websocket echo server example is available in -[examples directory](https://github.com/actix/actix-web/blob/master/examples/websocket.rs). +[examples directory](https://github.com/actix/actix-web/blob/master/examples/websocket). Example chat server with ability to chat over websocket connection or tcp connection is available in [websocket-chat directory](https://github.com/actix/actix-web/tree/master/examples/websocket-chat/) diff --git a/src/application.rs b/src/application.rs index c7c1bcacb..9f0e399b3 100644 --- a/src/application.rs +++ b/src/application.rs @@ -149,7 +149,7 @@ impl Application where S: 'static { pub fn with_state(state: S) -> Application { Application { parts: Some(ApplicationParts { - state: state, + state, prefix: "/".to_owned(), settings: ServerSettings::default(), default: Resource::default_not_found(), @@ -183,8 +183,8 @@ impl Application where S: 'static { /// let app = Application::new() /// .prefix("/app") /// .resource("/test", |r| { - /// r.method(Method::GET).f(|_| httpcodes::HTTPOk); - /// r.method(Method::HEAD).f(|_| httpcodes::HTTPMethodNotAllowed); + /// r.method(Method::GET).f(|_| httpcodes::HttpOk); + /// r.method(Method::HEAD).f(|_| httpcodes::HttpMethodNotAllowed); /// }) /// .finish(); /// } @@ -226,8 +226,8 @@ impl Application where S: 'static { /// fn main() { /// let app = Application::new() /// .resource("/test", |r| { - /// r.method(Method::GET).f(|_| httpcodes::HTTPOk); - /// r.method(Method::HEAD).f(|_| httpcodes::HTTPMethodNotAllowed); + /// r.method(Method::GET).f(|_| httpcodes::HttpOk); + /// r.method(Method::HEAD).f(|_| httpcodes::HttpMethodNotAllowed); /// }); /// } /// ``` @@ -281,7 +281,7 @@ impl Application where S: 'static { /// fn index(mut req: HttpRequest) -> Result { /// let url = req.url_for("youtube", &["oHg5SJYRHA0"])?; /// assert_eq!(url.as_str(), "https://youtube.com/watch/oHg5SJYRHA0"); - /// Ok(httpcodes::HTTPOk.into()) + /// Ok(httpcodes::HttpOk.into()) /// } /// /// fn main() { @@ -320,9 +320,9 @@ impl Application where S: 'static { /// let app = Application::new() /// .handler("/app", |req: HttpRequest| { /// match *req.method() { - /// Method::GET => httpcodes::HTTPOk, - /// Method::POST => httpcodes::HTTPMethodNotAllowed, - /// _ => httpcodes::HTTPNotFound, + /// Method::GET => httpcodes::HttpOk, + /// Method::POST => httpcodes::HttpMethodNotAllowed, + /// _ => httpcodes::HttpNotFound, /// }}); /// } /// ``` @@ -361,17 +361,17 @@ impl Application where S: 'static { default: parts.default, encoding: parts.encoding, router: router.clone(), - resources: resources, handlers: parts.handlers, + resources, } )); HttpApplication { state: Rc::new(parts.state), prefix: prefix.to_owned(), - inner: inner, router: router.clone(), middlewares: Rc::new(parts.middlewares), + inner, } } @@ -394,11 +394,11 @@ impl Application where S: 'static { /// HttpServer::new(|| { vec![ /// Application::with_state(State1) /// .prefix("/app1") - /// .resource("/", |r| r.h(httpcodes::HTTPOk)) + /// .resource("/", |r| r.h(httpcodes::HttpOk)) /// .boxed(), /// Application::with_state(State2) /// .prefix("/app2") - /// .resource("/", |r| r.h(httpcodes::HTTPOk)) + /// .resource("/", |r| r.h(httpcodes::HttpOk)) /// .boxed() ]}) /// .bind("127.0.0.1:8080").unwrap() /// .run() @@ -459,7 +459,7 @@ mod tests { #[test] fn test_default_resource() { let mut app = Application::new() - .resource("/test", |r| r.h(httpcodes::HTTPOk)) + .resource("/test", |r| r.h(httpcodes::HttpOk)) .finish(); let req = TestRequest::with_uri("/test").finish(); @@ -471,7 +471,7 @@ mod tests { assert_eq!(resp.as_response().unwrap().status(), StatusCode::NOT_FOUND); let mut app = Application::new() - .default_resource(|r| r.h(httpcodes::HTTPMethodNotAllowed)) + .default_resource(|r| r.h(httpcodes::HttpMethodNotAllowed)) .finish(); let req = TestRequest::with_uri("/blah").finish(); let resp = app.run(req); @@ -482,7 +482,7 @@ mod tests { fn test_unhandled_prefix() { let mut app = Application::new() .prefix("/test") - .resource("/test", |r| r.h(httpcodes::HTTPOk)) + .resource("/test", |r| r.h(httpcodes::HttpOk)) .finish(); assert!(app.handle(HttpRequest::default()).is_err()); } @@ -490,7 +490,7 @@ mod tests { #[test] fn test_state() { let mut app = Application::with_state(10) - .resource("/", |r| r.h(httpcodes::HTTPOk)) + .resource("/", |r| r.h(httpcodes::HttpOk)) .finish(); let req = HttpRequest::default().with_state(Rc::clone(&app.state), app.router.clone()); let resp = app.run(req); @@ -501,7 +501,7 @@ mod tests { fn test_prefix() { let mut app = Application::new() .prefix("/test") - .resource("/blah", |r| r.h(httpcodes::HTTPOk)) + .resource("/blah", |r| r.h(httpcodes::HttpOk)) .finish(); let req = TestRequest::with_uri("/test").finish(); let resp = app.handle(req); @@ -523,7 +523,7 @@ mod tests { #[test] fn test_handler() { let mut app = Application::new() - .handler("/test", httpcodes::HTTPOk) + .handler("/test", httpcodes::HttpOk) .finish(); let req = TestRequest::with_uri("/test").finish(); @@ -551,7 +551,7 @@ mod tests { fn test_handler_prefix() { let mut app = Application::new() .prefix("/app") - .handler("/test", httpcodes::HTTPOk) + .handler("/test", httpcodes::HttpOk) .finish(); let req = TestRequest::with_uri("/test").finish(); diff --git a/src/body.rs b/src/body.rs index ebd011e9c..fe6303438 100644 --- a/src/body.rs +++ b/src/body.rs @@ -36,6 +36,8 @@ pub enum Binary { /// Shared string body #[doc(hidden)] ArcSharedString(Arc), + /// Shared vec body + SharedVec(Arc>), } impl Body { @@ -115,6 +117,7 @@ impl Binary { Binary::Slice(slice) => slice.len(), Binary::SharedString(ref s) => s.len(), Binary::ArcSharedString(ref s) => s.len(), + Binary::SharedVec(ref s) => s.len(), } } @@ -134,8 +137,9 @@ impl Clone for Binary { match *self { Binary::Bytes(ref bytes) => Binary::Bytes(bytes.clone()), Binary::Slice(slice) => Binary::Bytes(Bytes::from(slice)), - Binary::SharedString(ref s) => Binary::Bytes(Bytes::from(s.as_str())), - Binary::ArcSharedString(ref s) => Binary::Bytes(Bytes::from(s.as_str())), + Binary::SharedString(ref s) => Binary::SharedString(s.clone()), + Binary::ArcSharedString(ref s) => Binary::ArcSharedString(s.clone()), + Binary::SharedVec(ref s) => Binary::SharedVec(s.clone()), } } } @@ -147,6 +151,7 @@ impl Into for Binary { Binary::Slice(slice) => Bytes::from(slice), Binary::SharedString(s) => Bytes::from(s.as_str()), Binary::ArcSharedString(s) => Bytes::from(s.as_str()), + Binary::SharedVec(s) => Bytes::from(AsRef::<[u8]>::as_ref(s.as_ref())), } } } @@ -217,6 +222,18 @@ impl<'a> From<&'a Arc> for Binary { } } +impl From>> for Binary { + fn from(body: Arc>) -> Binary { + Binary::SharedVec(body) + } +} + +impl<'a> From<&'a Arc>> for Binary { + fn from(body: &'a Arc>) -> Binary { + Binary::SharedVec(Arc::clone(body)) + } +} + impl AsRef<[u8]> for Binary { fn as_ref(&self) -> &[u8] { match *self { @@ -224,6 +241,7 @@ impl AsRef<[u8]> for Binary { Binary::Slice(slice) => slice, Binary::SharedString(ref s) => s.as_bytes(), Binary::ArcSharedString(ref s) => s.as_bytes(), + Binary::SharedVec(ref s) => s.as_ref().as_ref(), } } } @@ -304,6 +322,15 @@ mod tests { assert_eq!(Binary::from(&b).as_ref(), "test".as_bytes()); } + #[test] + fn test_shared_vec() { + let b = Arc::new(Vec::from(&b"test"[..])); + assert_eq!(Binary::from(b.clone()).len(), 4); + assert_eq!(Binary::from(b.clone()).as_ref(), &b"test"[..]); + assert_eq!(Binary::from(&b).len(), 4); + assert_eq!(Binary::from(&b).as_ref(), &b"test"[..]); + } + #[test] fn test_bytes_mut() { let b = BytesMut::from("test"); diff --git a/src/client/connector.rs b/src/client/connector.rs index 7acd4ed28..4e8ac214b 100644 --- a/src/client/connector.rs +++ b/src/client/connector.rs @@ -1,25 +1,22 @@ -#![allow(unused_imports, dead_code)] use std::{io, time}; -use std::net::{SocketAddr, Shutdown}; -use std::collections::VecDeque; -use std::time::Duration; +use std::net::Shutdown; -use actix::{fut, Actor, ActorFuture, Arbiter, Context, +use actix::{fut, Actor, ActorFuture, Context, Handler, Message, ActorResponse, Supervised}; use actix::registry::ArbiterService; use actix::fut::WrapFuture; use actix::actors::{Connector, ConnectorError, Connect as ResolveConnect}; use http::{Uri, HttpTryFrom, Error as HttpError}; -use futures::{Async, Future, Poll}; -use tokio_core::reactor::Timeout; -use tokio_core::net::{TcpStream, TcpStreamNew}; +use futures::Poll; use tokio_io::{AsyncRead, AsyncWrite}; #[cfg(feature="alpn")] -use openssl::ssl::{SslMethod, SslConnector, SslVerifyMode, Error as OpensslError}; +use openssl::ssl::{SslMethod, SslConnector, Error as OpensslError}; #[cfg(feature="alpn")] use tokio_openssl::SslConnectorExt; +#[cfg(feature="alpn")] +use futures::Future; use HAS_OPENSSL; use server::IoStream; @@ -97,7 +94,7 @@ impl Default for ClientConnector { fn default() -> ClientConnector { #[cfg(feature="alpn")] { - let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); + let builder = SslConnector::builder(SslMethod::tls()).unwrap(); ClientConnector { connector: builder.build() } @@ -154,9 +151,7 @@ impl ClientConnector { /// } /// ``` pub fn with_connector(connector: SslConnector) -> ClientConnector { - ClientConnector { - connector: connector - } + ClientConnector { connector } } } @@ -201,7 +196,7 @@ impl Handler for ClientConnector { if proto.is_secure() { fut::Either::A( _act.connector.connect_async(&host, stream) - .map_err(|e| ClientConnectorError::SslError(e)) + .map_err(ClientConnectorError::SslError) .map(|stream| Connection{stream: Box::new(stream)}) .into_actor(_act)) } else { diff --git a/src/client/mod.rs b/src/client/mod.rs index f7b735437..a4ee14178 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -8,7 +8,24 @@ mod writer; pub use self::pipeline::{SendRequest, SendRequestError}; pub use self::request::{ClientRequest, ClientRequestBuilder}; -pub use self::response::{ClientResponse, ResponseBody, JsonResponse, UrlEncoded}; +pub use self::response::ClientResponse; pub use self::connector::{Connect, Connection, ClientConnector, ClientConnectorError}; pub(crate) use self::writer::HttpClientWriter; pub(crate) use self::parser::{HttpResponseParser, HttpResponseParserError}; + + +use httpcodes; +use httpresponse::HttpResponse; +use error::ResponseError; + + +/// Convert `SendRequestError` to a `HttpResponse` +impl ResponseError for SendRequestError { + + fn error_response(&self) -> HttpResponse { + match *self { + SendRequestError::Connector(_) => httpcodes::HttpBadGateway.into(), + _ => httpcodes::HttpInternalServerError.into(), + } + } +} diff --git a/src/client/parser.rs b/src/client/parser.rs index b4ce9b2b2..8fe399009 100644 --- a/src/client/parser.rs +++ b/src/client/parser.rs @@ -37,25 +37,22 @@ impl HttpResponseParser { where T: IoStream { // if buf is empty parse_message will always return NotReady, let's avoid that - let read = if buf.is_empty() { + if buf.is_empty() { match utils::read_from_io(io, buf) { - Ok(Async::Ready(0)) => { - // debug!("Ignored premature client disconnection"); - return Err(HttpResponseParserError::Disconnect); - }, + Ok(Async::Ready(0)) => + return Err(HttpResponseParserError::Disconnect), Ok(Async::Ready(_)) => (), Ok(Async::NotReady) => return Ok(Async::NotReady), Err(err) => return Err(HttpResponseParserError::Error(err.into())) } - false - } else { - true - }; + } loop { - match HttpResponseParser::parse_message(buf).map_err(HttpResponseParserError::Error)? { + match HttpResponseParser::parse_message(buf) + .map_err(HttpResponseParserError::Error)? + { Async::Ready((msg, decoder)) => { self.decoder = decoder; return Ok(Async::Ready(msg)); @@ -64,15 +61,13 @@ impl HttpResponseParser { if buf.capacity() >= MAX_BUFFER_SIZE { return Err(HttpResponseParserError::Error(ParseError::TooLarge)); } - if read { - match utils::read_from_io(io, buf) { - Ok(Async::Ready(0)) => return Err(HttpResponseParserError::Disconnect), - Ok(Async::Ready(_)) => (), - Ok(Async::NotReady) => return Ok(Async::NotReady), - Err(err) => return Err(HttpResponseParserError::Error(err.into())), - } - } else { - return Ok(Async::NotReady) + match utils::read_from_io(io, buf) { + Ok(Async::Ready(0)) => + return Err(HttpResponseParserError::Disconnect), + Ok(Async::Ready(_)) => (), + Ok(Async::NotReady) => return Ok(Async::NotReady), + Err(err) => + return Err(HttpResponseParserError::Error(err.into())), } }, } @@ -83,20 +78,44 @@ impl HttpResponseParser { -> Poll, PayloadError> where T: IoStream { - if let Some(ref mut decoder) = self.decoder { - // read payload - match utils::read_from_io(io, buf) { - Ok(Async::Ready(0)) => return Err(PayloadError::Incomplete), - Err(err) => return Err(err.into()), - _ => (), + if self.decoder.is_some() { + loop { + // read payload + let not_ready = match utils::read_from_io(io, buf) { + Ok(Async::Ready(0)) => { + if buf.is_empty() { + return Err(PayloadError::Incomplete) + } + true + } + Err(err) => return Err(err.into()), + Ok(Async::NotReady) => true, + _ => false, + }; + + match self.decoder.as_mut().unwrap().decode(buf) { + Ok(Async::Ready(Some(b))) => + return Ok(Async::Ready(Some(b))), + Ok(Async::Ready(None)) => { + self.decoder.take(); + return Ok(Async::Ready(None)) + } + Ok(Async::NotReady) => { + if not_ready { + return Ok(Async::NotReady) + } + } + Err(err) => return Err(err.into()), + } } - decoder.decode(buf).map_err(|e| e.into()) } else { Ok(Async::Ready(None)) } } - fn parse_message(buf: &mut BytesMut) -> Poll<(ClientResponse, Option), ParseError> { + fn parse_message(buf: &mut BytesMut) + -> Poll<(ClientResponse, Option), ParseError> + { // Parse http message let bytes_ptr = buf.as_ref().as_ptr() as usize; let mut headers: [httparse::Header; MAX_HEADERS] = @@ -137,7 +156,9 @@ impl HttpResponseParser { } } - let decoder = if let Some(len) = hdrs.get(header::CONTENT_LENGTH) { + let decoder = if status == StatusCode::SWITCHING_PROTOCOLS { + Some(Decoder::eof()) + } else if let Some(len) = hdrs.get(header::CONTENT_LENGTH) { // Content-Length if let Ok(s) = len.to_str() { if let Ok(len) = s.parse::() { @@ -153,25 +174,19 @@ impl HttpResponseParser { } else if chunked(&hdrs)? { // Chunked encoding Some(Decoder::chunked()) - } else if hdrs.contains_key(header::UPGRADE) { - Some(Decoder::eof()) } else { None }; if let Some(decoder) = decoder { - //let info = PayloadInfo { - //tx: PayloadType::new(&hdrs, psender), - // decoder: decoder, - //}; Ok(Async::Ready( (ClientResponse::new( - ClientMessage{status: status, version: version, + ClientMessage{status, version, headers: hdrs, cookies: None}), Some(decoder)))) } else { Ok(Async::Ready( (ClientResponse::new( - ClientMessage{status: status, version: version, + ClientMessage{status, version, headers: hdrs, cookies: None}), None))) } } diff --git a/src/client/pipeline.rs b/src/client/pipeline.rs index d0f339d7f..baa84da9d 100644 --- a/src/client/pipeline.rs +++ b/src/client/pipeline.rs @@ -1,5 +1,6 @@ use std::{io, mem}; use bytes::{Bytes, BytesMut}; +use http::header::CONTENT_ENCODING; use futures::{Async, Future, Poll}; use futures::unsync::oneshot; @@ -8,9 +9,12 @@ use actix::prelude::*; use error::Error; use body::{Body, BodyStream}; use context::{Frame, ActorHttpContext}; +use headers::ContentEncoding; +use httpmessage::HttpMessage; use error::PayloadError; use server::WriterState; use server::shared::SharedBytes; +use server::encoding::PayloadStream; use super::{ClientRequest, ClientResponse}; use super::{Connect, Connection, ClientConnector, ClientConnectorError}; use super::HttpClientWriter; @@ -60,16 +64,13 @@ impl SendRequest { pub(crate) fn with_connector(req: ClientRequest, conn: Addr) -> SendRequest { - SendRequest{ - req: req, - state: State::New, - conn: conn} + SendRequest{req, conn, state: State::New} } pub(crate) fn with_connection(req: ClientRequest, conn: Connection) -> SendRequest { SendRequest{ - req: req, + req, state: State::Connection(conn), conn: ClientConnector::from_registry()} } @@ -100,7 +101,7 @@ impl Future for SendRequest { Err(_) => return Err(SendRequestError::Connector( ClientConnectorError::Disconnected)) }, - State::Connection(stream) => { + State::Connection(conn) => { let mut writer = HttpClientWriter::new(SharedBytes::default()); writer.start(&mut self.req)?; @@ -110,15 +111,15 @@ impl Future for SendRequest { _ => IoBody::Done, }; - let mut pl = Box::new(Pipeline { - body: body, - conn: stream, - writer: writer, - parser: HttpResponseParser::default(), + let pl = Box::new(Pipeline { + body, conn, writer, + parser: Some(HttpResponseParser::default()), parser_buf: BytesMut::new(), disconnected: false, - running: RunningState::Running, drain: None, + decompress: None, + should_decompress: self.req.response_decompress(), + write_state: RunningState::Running, }); self.state = State::Send(pl); }, @@ -150,11 +151,13 @@ pub(crate) struct Pipeline { body: IoBody, conn: Connection, writer: HttpClientWriter, - parser: HttpResponseParser, + parser: Option, parser_buf: BytesMut, disconnected: bool, - running: RunningState, drain: Option>, + decompress: Option, + should_decompress: bool, + write_state: RunningState, } enum IoBody { @@ -163,7 +166,7 @@ enum IoBody { Done, } -#[derive(PartialEq)] +#[derive(Debug, PartialEq)] enum RunningState { Running, Paused, @@ -189,25 +192,89 @@ impl Pipeline { #[inline] pub fn parse(&mut self) -> Poll { - self.parser.parse(&mut self.conn, &mut self.parser_buf) + match self.parser.as_mut().unwrap().parse(&mut self.conn, &mut self.parser_buf) { + Ok(Async::Ready(resp)) => { + // check content-encoding + if self.should_decompress { + if let Some(enc) = resp.headers().get(CONTENT_ENCODING) { + if let Ok(enc) = enc.to_str() { + match ContentEncoding::from(enc) { + ContentEncoding::Auto | ContentEncoding::Identity => (), + enc => self.decompress = Some(PayloadStream::new(enc)), + } + } + } + } + + Ok(Async::Ready(resp)) + } + val => val, + } } #[inline] pub fn poll(&mut self) -> Poll, PayloadError> { - self.poll_write() - .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("{}", e).as_str()))?; - Ok(self.parser.parse_payload(&mut self.conn, &mut self.parser_buf)?) + let mut need_run = false; + + // need write? + if let Async::NotReady = self.poll_write() + .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("{}", e)))? + { + need_run = true; + } + + // need read? + if self.parser.is_some() { + loop { + match self.parser.as_mut().unwrap() + .parse_payload(&mut self.conn, &mut self.parser_buf)? + { + Async::Ready(Some(b)) => { + if let Some(ref mut decompress) = self.decompress { + match decompress.feed_data(b) { + Ok(Some(b)) => return Ok(Async::Ready(Some(b))), + Ok(None) => return Ok(Async::NotReady), + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => + continue, + Err(err) => return Err(err.into()), + } + } else { + return Ok(Async::Ready(Some(b))) + } + }, + Async::Ready(None) => { + let _ = self.parser.take(); + break + } + Async::NotReady => return Ok(Async::NotReady), + } + } + } + + // eof + if let Some(mut decompress) = self.decompress.take() { + let res = decompress.feed_eof(); + if let Some(b) = res? { + return Ok(Async::Ready(Some(b))) + } + } + + if need_run { + Ok(Async::NotReady) + } else { + Ok(Async::Ready(None)) + } } #[inline] pub fn poll_write(&mut self) -> Poll<(), Error> { - if self.running == RunningState::Done { + if self.write_state == RunningState::Done { return Ok(Async::Ready(())) } let mut done = false; - if self.drain.is_none() && self.running != RunningState::Paused { + if self.drain.is_none() && self.write_state != RunningState::Paused { 'outter: loop { let result = match mem::replace(&mut self.body, IoBody::Done) { IoBody::Payload(mut body) => { @@ -243,6 +310,7 @@ impl Pipeline { match frame { Frame::Chunk(None) => { // info.context = Some(ctx); + self.disconnected = true; self.writer.write_eof()?; break 'outter }, @@ -253,7 +321,7 @@ impl Pipeline { } self.body = IoBody::Actor(ctx); if self.drain.is_some() { - self.running.resume(); + self.write_state.resume(); break } res.unwrap() @@ -270,6 +338,7 @@ impl Pipeline { } }, IoBody::Done => { + self.disconnected = true; done = true; break } @@ -277,11 +346,11 @@ impl Pipeline { match result { WriterState::Pause => { - self.running.pause(); + self.write_state.pause(); break } WriterState::Done => { - self.running.resume() + self.write_state.resume() }, } } @@ -290,14 +359,18 @@ impl Pipeline { // flush io but only if we need to match self.writer.poll_completed(&mut self.conn, false) { Ok(Async::Ready(_)) => { - self.running.resume(); + if self.disconnected { + self.write_state = RunningState::Done; + } else { + self.write_state.resume(); + } // resolve drain futures if let Some(tx) = self.drain.take() { let _ = tx.send(()); } // restart io processing - if !done { + if !done || self.write_state == RunningState::Done { self.poll_write() } else { Ok(Async::NotReady) diff --git a/src/client/request.rs b/src/client/request.rs index 37d95fa74..42682a30c 100644 --- a/src/client/request.rs +++ b/src/client/request.rs @@ -4,7 +4,7 @@ use std::io::Write; use actix::{Addr, Unsync}; use cookie::{Cookie, CookieJar}; use bytes::{BytesMut, BufMut}; -use http::{HeaderMap, Method, Version, Uri, HttpTryFrom, Error as HttpError}; +use http::{uri, HeaderMap, Method, Version, Uri, HttpTryFrom, Error as HttpError}; use http::header::{self, HeaderName, HeaderValue}; use serde_json; use serde::Serialize; @@ -25,6 +25,16 @@ pub struct ClientRequest { chunked: bool, upgrade: bool, encoding: ContentEncoding, + response_decompress: bool, + buffer_capacity: Option<(usize, usize)>, + conn: ConnectionType, + +} + +enum ConnectionType { + Default, + Connector(Addr), + Connection(Connection), } impl Default for ClientRequest { @@ -39,6 +49,9 @@ impl Default for ClientRequest { chunked: false, upgrade: false, encoding: ContentEncoding::Auto, + response_decompress: true, + buffer_capacity: None, + conn: ConnectionType::Default, } } } @@ -89,6 +102,7 @@ impl ClientRequest { request: Some(ClientRequest::default()), err: None, cookies: None, + default_headers: true, } } @@ -158,6 +172,16 @@ impl ClientRequest { self.encoding } + /// Decompress response payload + #[inline] + pub fn response_decompress(&self) -> bool { + self.response_decompress + } + + pub fn buffer_capacity(&self) -> Option<(usize, usize)> { + self.buffer_capacity + } + /// Get body os this response #[inline] pub fn body(&self) -> &Body { @@ -175,18 +199,14 @@ impl ClientRequest { } /// Send request - pub fn send(self) -> SendRequest { - SendRequest::new(self) - } - - /// Send request using custom connector - pub fn with_connector(self, conn: Addr) -> SendRequest { - SendRequest::with_connector(self, conn) - } - - /// Send request using existing Connection - pub fn with_connection(self, conn: Connection) -> SendRequest { - SendRequest::with_connection(self, conn) + /// + /// This method returns future that resolves to a ClientResponse + pub fn send(mut self) -> SendRequest { + match mem::replace(&mut self.conn, ConnectionType::Default) { + ConnectionType::Default => SendRequest::new(self), + ConnectionType::Connector(conn) => SendRequest::with_connector(self, conn), + ConnectionType::Connection(conn) => SendRequest::with_connection(self, conn), + } } } @@ -216,6 +236,7 @@ pub struct ClientRequestBuilder { request: Option, err: Option, cookies: Option, + default_headers: bool, } impl ClientRequestBuilder { @@ -409,6 +430,48 @@ impl ClientRequestBuilder { self } + /// Do not add default request headers. + /// By default `Accept-Encoding` header is set. + pub fn no_default_headers(&mut self) -> &mut Self { + self.default_headers = false; + self + } + + /// Disable automatic decompress response body + pub fn disable_decompress(&mut self) -> &mut Self { + if let Some(parts) = parts(&mut self.request, &self.err) { + parts.response_decompress = false; + } + self + } + + /// Set write buffer capacity + pub fn buffer_capacity(&mut self, + low_watermark: usize, + high_watermark: usize) -> &mut Self + { + if let Some(parts) = parts(&mut self.request, &self.err) { + parts.buffer_capacity = Some((low_watermark, high_watermark)); + } + self + } + + /// Send request using custom connector + pub fn with_connector(&mut self, conn: Addr) -> &mut Self { + if let Some(parts) = parts(&mut self.request, &self.err) { + parts.conn = ConnectionType::Connector(conn); + } + self + } + + /// Send request using existing Connection + pub fn with_connection(&mut self, conn: Connection) -> &mut Self { + if let Some(parts) = parts(&mut self.request, &self.err) { + parts.conn = ConnectionType::Connection(conn); + } + self + } + /// This method calls provided closure with builder reference if value is true. pub fn if_true(&mut self, value: bool, f: F) -> &mut Self where F: FnOnce(&mut ClientRequestBuilder) @@ -437,6 +500,23 @@ impl ClientRequestBuilder { return Err(e) } + if self.default_headers { + // enable br only for https + let https = + if let Some(parts) = parts(&mut self.request, &self.err) { + parts.uri.scheme_part() + .map(|s| s == &uri::Scheme::HTTPS).unwrap_or(true) + } else { + true + }; + + if https { + self.header(header::ACCEPT_ENCODING, "br, gzip, deflate"); + } else { + self.header(header::ACCEPT_ENCODING, "gzip, deflate"); + } + } + let mut request = self.request.take().expect("cannot reuse request builder"); // set cookies @@ -482,6 +562,7 @@ impl ClientRequestBuilder { request: self.request.take(), err: self.err.take(), cookies: self.cookies.take(), + default_headers: self.default_headers, } } } diff --git a/src/client/response.rs b/src/client/response.rs index 4bb7c2d66..392c91332 100644 --- a/src/client/response.rs +++ b/src/client/response.rs @@ -1,20 +1,15 @@ use std::{fmt, str}; use std::rc::Rc; use std::cell::UnsafeCell; -use std::collections::HashMap; -use bytes::{Bytes, BytesMut}; +use bytes::Bytes; use cookie::Cookie; -use futures::{Async, Future, Poll, Stream}; +use futures::{Async, Poll, Stream}; use http::{HeaderMap, StatusCode, Version}; use http::header::{self, HeaderValue}; -use mime::Mime; -use serde_json; -use serde::de::DeserializeOwned; -use url::form_urlencoded; -// use multipart::Multipart; -use error::{CookieParseError, ParseError, PayloadError, JsonPayloadError, UrlencodedError}; +use httpmessage::HttpMessage; +use error::{CookieParseError, PayloadError}; use super::pipeline::Pipeline; @@ -41,6 +36,14 @@ impl Default for ClientMessage { /// An HTTP Client response pub struct ClientResponse(Rc>, Option>); +impl HttpMessage for ClientResponse { + /// Get the headers from the response. + #[inline] + fn headers(&self) -> &HeaderMap { + &self.as_ref().headers + } +} + impl ClientResponse { pub(crate) fn new(msg: ClientMessage) -> ClientResponse { @@ -68,30 +71,12 @@ impl ClientResponse { self.as_ref().version } - /// Get the headers from the response. - #[inline] - pub fn headers(&self) -> &HeaderMap { - &self.as_ref().headers - } - - /// Get a mutable reference to the headers. - #[inline] - pub fn headers_mut(&mut self) -> &mut HeaderMap { - &mut self.as_mut().headers - } - /// Get the status from the server. #[inline] pub fn status(&self) -> StatusCode { self.as_ref().status } - /// Set the `StatusCode` for this response. - #[inline] - pub fn set_status(&mut self, status: StatusCode) { - self.as_mut().status = status - } - /// Load request cookies. pub fn cookies(&self) -> Result<&Vec>, CookieParseError> { if self.as_ref().cookies.is_none() { @@ -120,83 +105,6 @@ impl ClientResponse { } None } - - /// Read the request content type. If request does not contain - /// *Content-Type* header, empty str get returned. - pub fn content_type(&self) -> &str { - if let Some(content_type) = self.headers().get(header::CONTENT_TYPE) { - if let Ok(content_type) = content_type.to_str() { - return content_type.split(';').next().unwrap().trim() - } - } - "" - } - - /// Convert the request content type to a known mime type. - pub fn mime_type(&self) -> Option { - if let Some(content_type) = self.headers().get(header::CONTENT_TYPE) { - if let Ok(content_type) = content_type.to_str() { - return match content_type.parse() { - Ok(mt) => Some(mt), - Err(_) => None - }; - } - } - None - } - - /// Check if request has chunked transfer encoding - pub fn chunked(&self) -> Result { - if let Some(encodings) = self.headers().get(header::TRANSFER_ENCODING) { - if let Ok(s) = encodings.to_str() { - Ok(s.to_lowercase().contains("chunked")) - } else { - Err(ParseError::Header) - } - } else { - Ok(false) - } - } - - /// Load request body. - /// - /// By default only 256Kb payload reads to a memory, then connection get dropped - /// and `PayloadError` get returned. Use `ResponseBody::limit()` - /// method to change upper limit. - pub fn body(self) -> ResponseBody { - ResponseBody::new(self) - } - - // /// Return stream to http payload processes as multipart. - // /// - // /// Content-type: multipart/form-data; - // pub fn multipart(mut self) -> Multipart { - // Multipart::from_response(&mut self) - // } - - /// Parse `application/x-www-form-urlencoded` encoded body. - /// Return `UrlEncoded` future. It resolves to a `HashMap` which - /// contains decoded parameters. - /// - /// Returns error: - /// - /// * content type is not `application/x-www-form-urlencoded` - /// * transfer encoding is `chunked`. - /// * content-length is greater than 256k - pub fn urlencoded(self) -> UrlEncoded { - UrlEncoded::new(self) - } - - /// Parse `application/json` encoded body. - /// Return `JsonResponse` future. It resolves to a `T` value. - /// - /// Returns error: - /// - /// * content type is not `application/json` - /// * content length is greater than 256k - pub fn json(self) -> JsonResponse { - JsonResponse::from_response(self) - } } impl fmt::Debug for ClientResponse { @@ -229,230 +137,3 @@ impl Stream for ClientResponse { } } } - -/// Future that resolves to a complete response body. -#[must_use = "ResponseBody does nothing unless polled"] -pub struct ResponseBody { - limit: usize, - resp: Option, - fut: Option>>, -} - -impl ResponseBody { - - /// Create `ResponseBody` for request. - pub fn new(resp: ClientResponse) -> Self { - ResponseBody { - limit: 262_144, - resp: Some(resp), - fut: None, - } - } - - /// Change max size of payload. By default max size is 256Kb - pub fn limit(mut self, limit: usize) -> Self { - self.limit = limit; - self - } -} - -impl Future for ResponseBody { - type Item = Bytes; - type Error = PayloadError; - - fn poll(&mut self) -> Poll { - if let Some(resp) = self.resp.take() { - if let Some(len) = resp.headers().get(header::CONTENT_LENGTH) { - if let Ok(s) = len.to_str() { - if let Ok(len) = s.parse::() { - if len > self.limit { - return Err(PayloadError::Overflow); - } - } else { - return Err(PayloadError::Overflow); - } - } - } - let limit = self.limit; - let fut = resp.from_err() - .fold(BytesMut::new(), move |mut body, chunk| { - if (body.len() + chunk.len()) > limit { - Err(PayloadError::Overflow) - } else { - body.extend_from_slice(&chunk); - Ok(body) - } - }) - .map(|bytes| bytes.freeze()); - self.fut = Some(Box::new(fut)); - } - - self.fut.as_mut().expect("ResponseBody could not be used second time").poll() - } -} - -/// Client response json parser that resolves to a deserialized `T` value. -/// -/// Returns error: -/// -/// * content type is not `application/json` -/// * content length is greater than 256k -#[must_use = "JsonResponse does nothing unless polled"] -pub struct JsonResponse{ - limit: usize, - ct: &'static str, - resp: Option, - fut: Option>>, -} - -impl JsonResponse { - - /// Create `JsonResponse` for request. - pub fn from_response(resp: ClientResponse) -> Self { - JsonResponse{ - limit: 262_144, - resp: Some(resp), - ct: "application/json", - fut: None, - } - } - - /// Change max size of payload. By default max size is 256Kb - pub fn limit(mut self, limit: usize) -> Self { - self.limit = limit; - self - } - - /// Set allowed content type. - /// - /// By default *application/json* content type is used. Set content type - /// to empty string if you want to disable content type check. - pub fn content_type(mut self, ct: &'static str) -> Self { - self.ct = ct; - self - } -} - -impl Future for JsonResponse { - type Item = T; - type Error = JsonPayloadError; - - fn poll(&mut self) -> Poll { - if let Some(resp) = self.resp.take() { - if let Some(len) = resp.headers().get(header::CONTENT_LENGTH) { - if let Ok(s) = len.to_str() { - if let Ok(len) = s.parse::() { - if len > self.limit { - return Err(JsonPayloadError::Overflow); - } - } else { - return Err(JsonPayloadError::Overflow); - } - } - } - // check content-type - if !self.ct.is_empty() && resp.content_type() != self.ct { - return Err(JsonPayloadError::ContentType) - } - - let limit = self.limit; - let fut = resp.from_err() - .fold(BytesMut::new(), move |mut body, chunk| { - if (body.len() + chunk.len()) > limit { - Err(JsonPayloadError::Overflow) - } else { - body.extend_from_slice(&chunk); - Ok(body) - } - }) - .and_then(|body| Ok(serde_json::from_slice::(&body)?)); - - self.fut = Some(Box::new(fut)); - } - - self.fut.as_mut().expect("JsonResponse could not be used second time").poll() - } -} - -/// Future that resolves to a parsed urlencoded values. -#[must_use = "UrlEncoded does nothing unless polled"] -pub struct UrlEncoded { - resp: Option, - limit: usize, - fut: Option, Error=UrlencodedError>>>, -} - -impl UrlEncoded { - pub fn new(resp: ClientResponse) -> UrlEncoded { - UrlEncoded{resp: Some(resp), - limit: 262_144, - fut: None} - } - - /// Change max size of payload. By default max size is 256Kb - pub fn limit(mut self, limit: usize) -> Self { - self.limit = limit; - self - } -} - -impl Future for UrlEncoded { - type Item = HashMap; - type Error = UrlencodedError; - - fn poll(&mut self) -> Poll { - if let Some(resp) = self.resp.take() { - if resp.chunked().unwrap_or(false) { - return Err(UrlencodedError::Chunked) - } else if let Some(len) = resp.headers().get(header::CONTENT_LENGTH) { - if let Ok(s) = len.to_str() { - if let Ok(len) = s.parse::() { - if len > 262_144 { - return Err(UrlencodedError::Overflow); - } - } else { - return Err(UrlencodedError::UnknownLength); - } - } else { - return Err(UrlencodedError::UnknownLength); - } - } - - // check content type - let mut encoding = false; - if let Some(content_type) = resp.headers().get(header::CONTENT_TYPE) { - if let Ok(content_type) = content_type.to_str() { - if content_type.to_lowercase() == "application/x-www-form-urlencoded" { - encoding = true; - } - } - } - if !encoding { - return Err(UrlencodedError::ContentType); - } - - // urlencoded body - let limit = self.limit; - let fut = resp.from_err() - .fold(BytesMut::new(), move |mut body, chunk| { - if (body.len() + chunk.len()) > limit { - Err(UrlencodedError::Overflow) - } else { - body.extend_from_slice(&chunk); - Ok(body) - } - }) - .and_then(|body| { - let mut m = HashMap::new(); - for (k, v) in form_urlencoded::parse(&body) { - m.insert(k.into(), v.into()); - } - Ok(m) - }); - - self.fut = Some(Box::new(fut)); - } - - self.fut.as_mut().expect("UrlEncoded could not be used second time").poll() - } -} diff --git a/src/client/writer.rs b/src/client/writer.rs index ad1bb6a13..f67bd7261 100644 --- a/src/client/writer.rs +++ b/src/client/writer.rs @@ -1,4 +1,5 @@ -#![allow(dead_code)] +#![cfg_attr(feature = "cargo-clippy", allow(redundant_field_names))] + use std::io::{self, Write}; use std::cell::RefCell; use std::fmt::Write as FmtWrite; @@ -48,14 +49,14 @@ pub(crate) struct HttpClientWriter { impl HttpClientWriter { - pub fn new(buf: SharedBytes) -> HttpClientWriter { - let encoder = ContentEncoder::Identity(TransferEncoding::eof(buf.clone())); + pub fn new(buffer: SharedBytes) -> HttpClientWriter { + let encoder = ContentEncoder::Identity(TransferEncoding::eof(buffer.clone())); HttpClientWriter { flags: Flags::empty(), written: 0, headers_size: 0, - buffer: buf, - encoder: encoder, + buffer, + encoder, low: LOW_WATERMARK, high: HIGH_WATERMARK, } @@ -65,9 +66,9 @@ impl HttpClientWriter { self.buffer.take(); } - pub fn keepalive(&self) -> bool { - self.flags.contains(Flags::KEEPALIVE) && !self.flags.contains(Flags::UPGRADE) - } + // pub fn keepalive(&self) -> bool { + // self.flags.contains(Flags::KEEPALIVE) && !self.flags.contains(Flags::UPGRADE) + // } /// Set write buffer capacity pub fn set_buffer_capacity(&mut self, low_watermark: usize, high_watermark: usize) { @@ -105,6 +106,9 @@ impl HttpClientWriter { // prepare task self.flags.insert(Flags::STARTED); self.encoder = content_encoder(self.buffer.clone(), msg); + if let Some(capacity) = msg.buffer_capacity() { + self.set_buffer_capacity(capacity.0, capacity.1); + } // render message { diff --git a/src/context.rs b/src/context.rs index a3e168f6d..aa6f4c49a 100644 --- a/src/context.rs +++ b/src/context.rs @@ -164,6 +164,7 @@ impl HttpContext where A: Actor { self.stream = Some(SmallVec::new()); } self.stream.as_mut().map(|s| s.push(frame)); + self.inner.modify(); } /// Handle of the running future @@ -230,10 +231,7 @@ pub struct Drain { impl Drain { pub fn new(fut: oneshot::Receiver<()>) -> Self { - Drain { - fut: fut, - _a: PhantomData - } + Drain { fut, _a: PhantomData } } } diff --git a/src/error.rs b/src/error.rs index 513c0f4d0..6abbf7a0f 100644 --- a/src/error.rs +++ b/src/error.rs @@ -24,7 +24,7 @@ use body::Body; use handler::Responder; use httprequest::HttpRequest; use httpresponse::HttpResponse; -use httpcodes::{self, HTTPBadRequest, HTTPMethodNotAllowed, HTTPExpectationFailed}; +use httpcodes::{self, HttpExpectationFailed}; /// A specialized [`Result`](https://doc.rust-lang.org/std/result/enum.Result.html) /// for actix web operations @@ -90,14 +90,13 @@ impl From for Error { } else { None }; - Error { cause: Box::new(err), backtrace: backtrace } + Error { cause: Box::new(err), backtrace } } } /// Compatibility for `failure::Error` impl ResponseError for failure::Compat - where T: fmt::Display + fmt::Debug + Sync + Send + 'static -{ } + where T: fmt::Display + fmt::Debug + Sync + Send + 'static { } impl From for Error { fn from(err: failure::Error) -> Error { @@ -293,6 +292,9 @@ pub enum MultipartError { /// Multipart boundary is not found #[fail(display="Multipart boundary is not found")] Boundary, + /// Multipart stream is incomplete + #[fail(display="Multipart stream is incomplete")] + Incomplete, /// Error during field parsing #[fail(display="{}", _0)] Parse(#[cause] ParseError), @@ -333,57 +335,26 @@ pub enum ExpectError { } impl ResponseError for ExpectError { - fn error_response(&self) -> HttpResponse { - HTTPExpectationFailed.with_body("Unknown Expect") + HttpExpectationFailed.with_body("Unknown Expect") } } -/// Websocket handshake errors +/// A set of error that can occure during parsing content type #[derive(Fail, PartialEq, Debug)] -pub enum WsHandshakeError { - /// Only get method is allowed - #[fail(display="Method not allowed")] - GetMethodRequired, - /// Upgrade header if not set to websocket - #[fail(display="Websocket upgrade is expected")] - NoWebsocketUpgrade, - /// Connection header is not set to upgrade - #[fail(display="Connection upgrade is expected")] - NoConnectionUpgrade, - /// Websocket version header is not set - #[fail(display="Websocket version header is required")] - NoVersionHeader, - /// Unsupported websocket version - #[fail(display="Unsupported version")] - UnsupportedVersion, - /// Websocket key is not set or wrong - #[fail(display="Unknown websocket key")] - BadWebsocketKey, +pub enum ContentTypeError { + /// Can not parse content type + #[fail(display="Can not parse content type")] + ParseError, + /// Unknown content encoding + #[fail(display="Unknown content encoding")] + UnknownEncoding, } -impl ResponseError for WsHandshakeError { - +/// Return `BadRequest` for `ContentTypeError` +impl ResponseError for ContentTypeError { fn error_response(&self) -> HttpResponse { - match *self { - WsHandshakeError::GetMethodRequired => { - HTTPMethodNotAllowed - .build() - .header(header::ALLOW, "GET") - .finish() - .unwrap() - } - WsHandshakeError::NoWebsocketUpgrade => - HTTPBadRequest.with_reason("No WebSocket UPGRADE header found"), - WsHandshakeError::NoConnectionUpgrade => - HTTPBadRequest.with_reason("No CONNECTION upgrade"), - WsHandshakeError::NoVersionHeader => - HTTPBadRequest.with_reason("Websocket version header is required"), - WsHandshakeError::UnsupportedVersion => - HTTPBadRequest.with_reason("Unsupported version"), - WsHandshakeError::BadWebsocketKey => - HTTPBadRequest.with_reason("Handshake error"), - } + HttpResponse::new(StatusCode::BAD_REQUEST, Body::Empty) } } @@ -402,6 +373,9 @@ pub enum UrlencodedError { /// Content type error #[fail(display="Content type error")] ContentType, + /// Parse error + #[fail(display="Parse error")] + Parse, /// Payload error #[fail(display="Error that occur during reading payload: {}", _0)] Payload(#[cause] PayloadError), @@ -412,9 +386,9 @@ impl ResponseError for UrlencodedError { fn error_response(&self) -> HttpResponse { match *self { - UrlencodedError::Overflow => httpcodes::HTTPPayloadTooLarge.into(), - UrlencodedError::UnknownLength => httpcodes::HTTPLengthRequired.into(), - _ => httpcodes::HTTPBadRequest.into(), + UrlencodedError::Overflow => httpcodes::HttpPayloadTooLarge.into(), + UrlencodedError::UnknownLength => httpcodes::HttpLengthRequired.into(), + _ => httpcodes::HttpBadRequest.into(), } } } @@ -447,8 +421,8 @@ impl ResponseError for JsonPayloadError { fn error_response(&self) -> HttpResponse { match *self { - JsonPayloadError::Overflow => httpcodes::HTTPPayloadTooLarge.into(), - _ => httpcodes::HTTPBadRequest.into(), + JsonPayloadError::Overflow => httpcodes::HttpPayloadTooLarge.into(), + _ => httpcodes::HttpBadRequest.into(), } } } @@ -536,10 +510,10 @@ unsafe impl Sync for InternalError {} unsafe impl Send for InternalError {} impl InternalError { - pub fn new(err: T, status: StatusCode) -> Self { + pub fn new(cause: T, status: StatusCode) -> Self { InternalError { - cause: err, - status: status, + cause, + status, backtrace: Backtrace::new(), } } @@ -739,22 +713,6 @@ mod tests { assert_eq!(resp.status(), StatusCode::EXPECTATION_FAILED); } - #[test] - fn test_wserror_http_response() { - let resp: HttpResponse = WsHandshakeError::GetMethodRequired.error_response(); - assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); - let resp: HttpResponse = WsHandshakeError::NoWebsocketUpgrade.error_response(); - assert_eq!(resp.status(), StatusCode::BAD_REQUEST); - let resp: HttpResponse = WsHandshakeError::NoConnectionUpgrade.error_response(); - assert_eq!(resp.status(), StatusCode::BAD_REQUEST); - let resp: HttpResponse = WsHandshakeError::NoVersionHeader.error_response(); - assert_eq!(resp.status(), StatusCode::BAD_REQUEST); - let resp: HttpResponse = WsHandshakeError::UnsupportedVersion.error_response(); - assert_eq!(resp.status(), StatusCode::BAD_REQUEST); - let resp: HttpResponse = WsHandshakeError::BadWebsocketKey.error_response(); - assert_eq!(resp.status(), StatusCode::BAD_REQUEST); - } - macro_rules! from { ($from:expr => $error:pat) => { match ParseError::from($from) { diff --git a/src/fs.rs b/src/fs.rs index 7282e3ed4..88d52a2f4 100644 --- a/src/fs.rs +++ b/src/fs.rs @@ -143,10 +143,7 @@ pub struct Directory{ impl Directory { pub fn new(base: PathBuf, path: PathBuf) -> Directory { - Directory { - base: base, - path: path - } + Directory { base, path } } fn can_list(&self, entry: &io::Result) -> bool { @@ -205,7 +202,7 @@ impl Responder for Directory {
    \ {}\
\n", index_of, index_of, body); - Ok(HTTPOk.build() + Ok(HttpOk.build() .content_type("text/html; charset=utf-8") .body(html).unwrap()) } @@ -330,7 +327,7 @@ impl Handler for StaticFiles { } new_path.push_str(redir_index); Ok(FilesystemElement::Redirect( - HTTPFound + HttpFound .build() .header::<_, &str>("LOCATION", &new_path) .finish().unwrap())) diff --git a/src/handler.rs b/src/handler.rs index 857c95398..4aa5ec5b4 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -215,7 +215,7 @@ impl WrapHandler S: 'static, { pub fn new(h: H) -> Self { - WrapHandler{h: h, s: PhantomData} + WrapHandler{h, s: PhantomData} } } @@ -225,7 +225,7 @@ impl RouteHandler for WrapHandler S: 'static, { fn handle(&mut self, req: HttpRequest) -> Reply { - let req2 = req.clone_without_state(); + let req2 = req.without_state(); match self.h.handle(req).respond_to(req2) { Ok(reply) => reply.into(), Err(err) => Reply::response(err.into()), @@ -266,7 +266,7 @@ impl RouteHandler for AsyncHandler S: 'static, { fn handle(&mut self, req: HttpRequest) -> Reply { - let req2 = req.clone_without_state(); + let req2 = req.without_state(); let fut = (self.h)(req) .map_err(|e| e.into()) .then(move |r| { @@ -309,7 +309,7 @@ impl RouteHandler for AsyncHandler /// # use actix_web::*; /// # /// # fn index(req: HttpRequest) -> httpcodes::StaticResponse { -/// # httpcodes::HTTPOk +/// # httpcodes::HttpOk /// # } /// fn main() { /// let app = Application::new() @@ -345,10 +345,10 @@ impl NormalizePath { /// Create new `NormalizePath` instance pub fn new(append: bool, merge: bool, redirect: StatusCode) -> NormalizePath { NormalizePath { - append: append, - merge: merge, + append, + merge, + redirect, re_merge: Regex::new("//+").unwrap(), - redirect: redirect, not_found: StatusCode::NOT_FOUND, } } @@ -395,7 +395,6 @@ impl Handler for NormalizePath { } } } else if p.ends_with('/') { - println!("=== {:?}", p); // try to remove trailing slash let p = p.as_ref().trim_right_matches('/'); if router.has_route(p) { diff --git a/src/helpers.rs b/src/helpers.rs index 25e22b8fe..5f54f48f9 100644 --- a/src/helpers.rs +++ b/src/helpers.rs @@ -8,7 +8,7 @@ use time; use bytes::{BufMut, BytesMut}; use http::Version; -use httprequest::HttpMessage; +use httprequest::HttpInnerMessage; // "Sun, 06 Nov 1994 08:49:37 GMT".len() pub(crate) const DATE_VALUE_LENGTH: usize = 29; @@ -67,7 +67,7 @@ impl fmt::Write for CachedDate { } /// Internal use only! unsafe -pub(crate) struct SharedMessagePool(RefCell>>); +pub(crate) struct SharedMessagePool(RefCell>>); impl SharedMessagePool { pub fn new() -> SharedMessagePool { @@ -75,16 +75,16 @@ impl SharedMessagePool { } #[inline] - pub fn get(&self) -> Rc { + pub fn get(&self) -> Rc { if let Some(msg) = self.0.borrow_mut().pop_front() { msg } else { - Rc::new(HttpMessage::default()) + Rc::new(HttpInnerMessage::default()) } } #[inline] - pub fn release(&self, mut msg: Rc) { + pub fn release(&self, mut msg: Rc) { let v = &mut self.0.borrow_mut(); if v.len() < 128 { Rc::get_mut(&mut msg).unwrap().reset(); @@ -93,10 +93,10 @@ impl SharedMessagePool { } } -pub(crate) struct SharedHttpMessage( - Option>, Option>); +pub(crate) struct SharedHttpInnerMessage( + Option>, Option>); -impl Drop for SharedHttpMessage { +impl Drop for SharedHttpInnerMessage { fn drop(&mut self) { if let Some(ref pool) = self.1 { if let Some(msg) = self.0.take() { @@ -108,56 +108,56 @@ impl Drop for SharedHttpMessage { } } -impl Deref for SharedHttpMessage { - type Target = HttpMessage; +impl Deref for SharedHttpInnerMessage { + type Target = HttpInnerMessage; - fn deref(&self) -> &HttpMessage { + fn deref(&self) -> &HttpInnerMessage { self.get_ref() } } -impl DerefMut for SharedHttpMessage { +impl DerefMut for SharedHttpInnerMessage { - fn deref_mut(&mut self) -> &mut HttpMessage { + fn deref_mut(&mut self) -> &mut HttpInnerMessage { self.get_mut() } } -impl Clone for SharedHttpMessage { +impl Clone for SharedHttpInnerMessage { - fn clone(&self) -> SharedHttpMessage { - SharedHttpMessage(self.0.clone(), self.1.clone()) + fn clone(&self) -> SharedHttpInnerMessage { + SharedHttpInnerMessage(self.0.clone(), self.1.clone()) } } -impl Default for SharedHttpMessage { +impl Default for SharedHttpInnerMessage { - fn default() -> SharedHttpMessage { - SharedHttpMessage(Some(Rc::new(HttpMessage::default())), None) + fn default() -> SharedHttpInnerMessage { + SharedHttpInnerMessage(Some(Rc::new(HttpInnerMessage::default())), None) } } -impl SharedHttpMessage { +impl SharedHttpInnerMessage { - pub fn from_message(msg: HttpMessage) -> SharedHttpMessage { - SharedHttpMessage(Some(Rc::new(msg)), None) + pub fn from_message(msg: HttpInnerMessage) -> SharedHttpInnerMessage { + SharedHttpInnerMessage(Some(Rc::new(msg)), None) } - pub fn new(msg: Rc, pool: Rc) -> SharedHttpMessage { - SharedHttpMessage(Some(msg), Some(pool)) + pub fn new(msg: Rc, pool: Rc) -> SharedHttpInnerMessage { + SharedHttpInnerMessage(Some(msg), Some(pool)) } #[inline(always)] #[allow(mutable_transmutes)] #[cfg_attr(feature = "cargo-clippy", allow(mut_from_ref, inline_always))] - pub fn get_mut(&self) -> &mut HttpMessage { - let r: &HttpMessage = self.0.as_ref().unwrap().as_ref(); + pub fn get_mut(&self) -> &mut HttpInnerMessage { + let r: &HttpInnerMessage = self.0.as_ref().unwrap().as_ref(); unsafe{mem::transmute(r)} } #[inline(always)] #[cfg_attr(feature = "cargo-clippy", allow(inline_always))] - pub fn get_ref(&self) -> &HttpMessage { + pub fn get_ref(&self) -> &HttpInnerMessage { self.0.as_ref().unwrap() } } diff --git a/src/httpcodes.rs b/src/httpcodes.rs index 49ca8ec63..326df2613 100644 --- a/src/httpcodes.rs +++ b/src/httpcodes.rs @@ -7,67 +7,174 @@ use handler::{Reply, Handler, RouteHandler, Responder}; use httprequest::HttpRequest; use httpresponse::{HttpResponse, HttpResponseBuilder}; +pub const HttpOk: StaticResponse = StaticResponse(StatusCode::OK); +pub const HttpCreated: StaticResponse = StaticResponse(StatusCode::CREATED); +pub const HttpAccepted: StaticResponse = StaticResponse(StatusCode::ACCEPTED); +pub const HttpNonAuthoritativeInformation: StaticResponse = + StaticResponse(StatusCode::NON_AUTHORITATIVE_INFORMATION); +pub const HttpNoContent: StaticResponse = StaticResponse(StatusCode::NO_CONTENT); +pub const HttpResetContent: StaticResponse = StaticResponse(StatusCode::RESET_CONTENT); +pub const HttpPartialContent: StaticResponse = StaticResponse(StatusCode::PARTIAL_CONTENT); +pub const HttpMultiStatus: StaticResponse = StaticResponse(StatusCode::MULTI_STATUS); +pub const HttpAlreadyReported: StaticResponse = StaticResponse(StatusCode::ALREADY_REPORTED); + +pub const HttpMultipleChoices: StaticResponse = StaticResponse(StatusCode::MULTIPLE_CHOICES); +pub const HttpMovedPermanenty: StaticResponse = StaticResponse(StatusCode::MOVED_PERMANENTLY); +pub const HttpFound: StaticResponse = StaticResponse(StatusCode::FOUND); +pub const HttpSeeOther: StaticResponse = StaticResponse(StatusCode::SEE_OTHER); +pub const HttpNotModified: StaticResponse = StaticResponse(StatusCode::NOT_MODIFIED); +pub const HttpUseProxy: StaticResponse = StaticResponse(StatusCode::USE_PROXY); +pub const HttpTemporaryRedirect: StaticResponse = + StaticResponse(StatusCode::TEMPORARY_REDIRECT); +pub const HttpPermanentRedirect: StaticResponse = + StaticResponse(StatusCode::PERMANENT_REDIRECT); + +pub const HttpBadRequest: StaticResponse = StaticResponse(StatusCode::BAD_REQUEST); +pub const HttpUnauthorized: StaticResponse = StaticResponse(StatusCode::UNAUTHORIZED); +pub const HttpPaymentRequired: StaticResponse = StaticResponse(StatusCode::PAYMENT_REQUIRED); +pub const HttpForbidden: StaticResponse = StaticResponse(StatusCode::FORBIDDEN); +pub const HttpNotFound: StaticResponse = StaticResponse(StatusCode::NOT_FOUND); +pub const HttpMethodNotAllowed: StaticResponse = + StaticResponse(StatusCode::METHOD_NOT_ALLOWED); +pub const HttpNotAcceptable: StaticResponse = StaticResponse(StatusCode::NOT_ACCEPTABLE); +pub const HttpProxyAuthenticationRequired: StaticResponse = + StaticResponse(StatusCode::PROXY_AUTHENTICATION_REQUIRED); +pub const HttpRequestTimeout: StaticResponse = StaticResponse(StatusCode::REQUEST_TIMEOUT); +pub const HttpConflict: StaticResponse = StaticResponse(StatusCode::CONFLICT); +pub const HttpGone: StaticResponse = StaticResponse(StatusCode::GONE); +pub const HttpLengthRequired: StaticResponse = StaticResponse(StatusCode::LENGTH_REQUIRED); +pub const HttpPreconditionFailed: StaticResponse = + StaticResponse(StatusCode::PRECONDITION_FAILED); +pub const HttpPayloadTooLarge: StaticResponse = StaticResponse(StatusCode::PAYLOAD_TOO_LARGE); +pub const HttpUriTooLong: StaticResponse = StaticResponse(StatusCode::URI_TOO_LONG); +pub const HttpUnsupportedMediaType: StaticResponse = + StaticResponse(StatusCode::UNSUPPORTED_MEDIA_TYPE); +pub const HttpRangeNotSatisfiable: StaticResponse = + StaticResponse(StatusCode::RANGE_NOT_SATISFIABLE); +pub const HttpExpectationFailed: StaticResponse = + StaticResponse(StatusCode::EXPECTATION_FAILED); + +pub const HttpInternalServerError: StaticResponse = + StaticResponse(StatusCode::INTERNAL_SERVER_ERROR); +pub const HttpNotImplemented: StaticResponse = StaticResponse(StatusCode::NOT_IMPLEMENTED); +pub const HttpBadGateway: StaticResponse = StaticResponse(StatusCode::BAD_GATEWAY); +pub const HttpServiceUnavailable: StaticResponse = + StaticResponse(StatusCode::SERVICE_UNAVAILABLE); +pub const HttpGatewayTimeout: StaticResponse = + StaticResponse(StatusCode::GATEWAY_TIMEOUT); +pub const HttpVersionNotSupported: StaticResponse = + StaticResponse(StatusCode::HTTP_VERSION_NOT_SUPPORTED); +pub const HttpVariantAlsoNegotiates: StaticResponse = + StaticResponse(StatusCode::VARIANT_ALSO_NEGOTIATES); +pub const HttpInsufficientStorage: StaticResponse = + StaticResponse(StatusCode::INSUFFICIENT_STORAGE); +pub const HttpLoopDetected: StaticResponse = StaticResponse(StatusCode::LOOP_DETECTED); + +#[doc(hidden)] pub const HTTPOk: StaticResponse = StaticResponse(StatusCode::OK); +#[doc(hidden)] pub const HTTPCreated: StaticResponse = StaticResponse(StatusCode::CREATED); +#[doc(hidden)] pub const HTTPAccepted: StaticResponse = StaticResponse(StatusCode::ACCEPTED); +#[doc(hidden)] pub const HTTPNonAuthoritativeInformation: StaticResponse = StaticResponse(StatusCode::NON_AUTHORITATIVE_INFORMATION); +#[doc(hidden)] pub const HTTPNoContent: StaticResponse = StaticResponse(StatusCode::NO_CONTENT); +#[doc(hidden)] pub const HTTPResetContent: StaticResponse = StaticResponse(StatusCode::RESET_CONTENT); +#[doc(hidden)] pub const HTTPPartialContent: StaticResponse = StaticResponse(StatusCode::PARTIAL_CONTENT); +#[doc(hidden)] pub const HTTPMultiStatus: StaticResponse = StaticResponse(StatusCode::MULTI_STATUS); +#[doc(hidden)] pub const HTTPAlreadyReported: StaticResponse = StaticResponse(StatusCode::ALREADY_REPORTED); +#[doc(hidden)] pub const HTTPMultipleChoices: StaticResponse = StaticResponse(StatusCode::MULTIPLE_CHOICES); +#[doc(hidden)] pub const HTTPMovedPermanenty: StaticResponse = StaticResponse(StatusCode::MOVED_PERMANENTLY); +#[doc(hidden)] pub const HTTPFound: StaticResponse = StaticResponse(StatusCode::FOUND); +#[doc(hidden)] pub const HTTPSeeOther: StaticResponse = StaticResponse(StatusCode::SEE_OTHER); +#[doc(hidden)] pub const HTTPNotModified: StaticResponse = StaticResponse(StatusCode::NOT_MODIFIED); +#[doc(hidden)] pub const HTTPUseProxy: StaticResponse = StaticResponse(StatusCode::USE_PROXY); +#[doc(hidden)] pub const HTTPTemporaryRedirect: StaticResponse = StaticResponse(StatusCode::TEMPORARY_REDIRECT); +#[doc(hidden)] pub const HTTPPermanentRedirect: StaticResponse = StaticResponse(StatusCode::PERMANENT_REDIRECT); +#[doc(hidden)] pub const HTTPBadRequest: StaticResponse = StaticResponse(StatusCode::BAD_REQUEST); +#[doc(hidden)] pub const HTTPUnauthorized: StaticResponse = StaticResponse(StatusCode::UNAUTHORIZED); +#[doc(hidden)] pub const HTTPPaymentRequired: StaticResponse = StaticResponse(StatusCode::PAYMENT_REQUIRED); +#[doc(hidden)] pub const HTTPForbidden: StaticResponse = StaticResponse(StatusCode::FORBIDDEN); +#[doc(hidden)] pub const HTTPNotFound: StaticResponse = StaticResponse(StatusCode::NOT_FOUND); +#[doc(hidden)] pub const HTTPMethodNotAllowed: StaticResponse = StaticResponse(StatusCode::METHOD_NOT_ALLOWED); +#[doc(hidden)] pub const HTTPNotAcceptable: StaticResponse = StaticResponse(StatusCode::NOT_ACCEPTABLE); +#[doc(hidden)] pub const HTTPProxyAuthenticationRequired: StaticResponse = StaticResponse(StatusCode::PROXY_AUTHENTICATION_REQUIRED); +#[doc(hidden)] pub const HTTPRequestTimeout: StaticResponse = StaticResponse(StatusCode::REQUEST_TIMEOUT); +#[doc(hidden)] pub const HTTPConflict: StaticResponse = StaticResponse(StatusCode::CONFLICT); +#[doc(hidden)] pub const HTTPGone: StaticResponse = StaticResponse(StatusCode::GONE); +#[doc(hidden)] pub const HTTPLengthRequired: StaticResponse = StaticResponse(StatusCode::LENGTH_REQUIRED); +#[doc(hidden)] pub const HTTPPreconditionFailed: StaticResponse = StaticResponse(StatusCode::PRECONDITION_FAILED); +#[doc(hidden)] pub const HTTPPayloadTooLarge: StaticResponse = StaticResponse(StatusCode::PAYLOAD_TOO_LARGE); +#[doc(hidden)] pub const HTTPUriTooLong: StaticResponse = StaticResponse(StatusCode::URI_TOO_LONG); +#[doc(hidden)] pub const HTTPUnsupportedMediaType: StaticResponse = StaticResponse(StatusCode::UNSUPPORTED_MEDIA_TYPE); +#[doc(hidden)] pub const HTTPRangeNotSatisfiable: StaticResponse = StaticResponse(StatusCode::RANGE_NOT_SATISFIABLE); +#[doc(hidden)] pub const HTTPExpectationFailed: StaticResponse = StaticResponse(StatusCode::EXPECTATION_FAILED); +#[doc(hidden)] pub const HTTPInternalServerError: StaticResponse = StaticResponse(StatusCode::INTERNAL_SERVER_ERROR); +#[doc(hidden)] pub const HTTPNotImplemented: StaticResponse = StaticResponse(StatusCode::NOT_IMPLEMENTED); +#[doc(hidden)] pub const HTTPBadGateway: StaticResponse = StaticResponse(StatusCode::BAD_GATEWAY); +#[doc(hidden)] pub const HTTPServiceUnavailable: StaticResponse = StaticResponse(StatusCode::SERVICE_UNAVAILABLE); +#[doc(hidden)] pub const HTTPGatewayTimeout: StaticResponse = StaticResponse(StatusCode::GATEWAY_TIMEOUT); +#[doc(hidden)] pub const HTTPVersionNotSupported: StaticResponse = StaticResponse(StatusCode::HTTP_VERSION_NOT_SUPPORTED); +#[doc(hidden)] pub const HTTPVariantAlsoNegotiates: StaticResponse = StaticResponse(StatusCode::VARIANT_ALSO_NEGOTIATES); +#[doc(hidden)] pub const HTTPInsufficientStorage: StaticResponse = StaticResponse(StatusCode::INSUFFICIENT_STORAGE); +#[doc(hidden)] pub const HTTPLoopDetected: StaticResponse = StaticResponse(StatusCode::LOOP_DETECTED); diff --git a/src/httpmessage.rs b/src/httpmessage.rs new file mode 100644 index 000000000..60132136c --- /dev/null +++ b/src/httpmessage.rs @@ -0,0 +1,596 @@ +use std::str; +use std::collections::HashMap; +use bytes::{Bytes, BytesMut}; +use futures::{Future, Stream, Poll}; +use http_range::HttpRange; +use serde::de::DeserializeOwned; +use mime::Mime; +use url::form_urlencoded; +use encoding::all::UTF_8; +use encoding::EncodingRef; +use encoding::label::encoding_from_whatwg_label; +use http::{header, HeaderMap}; + +use json::JsonBody; +use multipart::Multipart; +use error::{ParseError, ContentTypeError, + HttpRangeError, PayloadError, UrlencodedError}; + + +/// Trait that implements general purpose operations on http messages +pub trait HttpMessage { + + /// Read the message headers. + fn headers(&self) -> &HeaderMap; + + /// Read the request content type. If request does not contain + /// *Content-Type* header, empty str get returned. + fn content_type(&self) -> &str { + if let Some(content_type) = self.headers().get(header::CONTENT_TYPE) { + if let Ok(content_type) = content_type.to_str() { + return content_type.split(';').next().unwrap().trim() + } + } + "" + } + + /// Get content type encoding + /// + /// UTF-8 is used by default, If request charset is not set. + fn encoding(&self) -> Result { + if let Some(mime_type) = self.mime_type()? { + if let Some(charset) = mime_type.get_param("charset") { + if let Some(enc) = encoding_from_whatwg_label(charset.as_str()) { + Ok(enc) + } else { + Err(ContentTypeError::UnknownEncoding) + } + } else { + Ok(UTF_8) + } + } else { + Ok(UTF_8) + } + } + + /// Convert the request content type to a known mime type. + fn mime_type(&self) -> Result, ContentTypeError> { + if let Some(content_type) = self.headers().get(header::CONTENT_TYPE) { + if let Ok(content_type) = content_type.to_str() { + return match content_type.parse() { + Ok(mt) => Ok(Some(mt)), + Err(_) => Err(ContentTypeError::ParseError), + }; + } else { + return Err(ContentTypeError::ParseError) + } + } + Ok(None) + } + + /// Check if request has chunked transfer encoding + fn chunked(&self) -> Result { + if let Some(encodings) = self.headers().get(header::TRANSFER_ENCODING) { + if let Ok(s) = encodings.to_str() { + Ok(s.to_lowercase().contains("chunked")) + } else { + Err(ParseError::Header) + } + } else { + Ok(false) + } + } + + /// Parses Range HTTP header string as per RFC 2616. + /// `size` is full size of response (file). + fn range(&self, size: u64) -> Result, HttpRangeError> { + if let Some(range) = self.headers().get(header::RANGE) { + HttpRange::parse(unsafe{str::from_utf8_unchecked(range.as_bytes())}, size) + .map_err(|e| e.into()) + } else { + Ok(Vec::new()) + } + } + + /// Load http message body. + /// + /// By default only 256Kb payload reads to a memory, then `PayloadError::Overflow` + /// get returned. Use `MessageBody::limit()` method to change upper limit. + /// + /// ## Server example + /// + /// ```rust + /// # extern crate bytes; + /// # extern crate actix_web; + /// # extern crate futures; + /// # #[macro_use] extern crate serde_derive; + /// use actix_web::*; + /// use bytes::Bytes; + /// use futures::future::Future; + /// + /// fn index(mut req: HttpRequest) -> Box> { + /// req.body() // <- get Body future + /// .limit(1024) // <- change max size of the body to a 1kb + /// .from_err() + /// .and_then(|bytes: Bytes| { // <- complete body + /// println!("==== BODY ==== {:?}", bytes); + /// Ok(httpcodes::HttpOk.into()) + /// }).responder() + /// } + /// # fn main() {} + /// ``` + fn body(self) -> MessageBody + where Self: Stream + Sized + { + MessageBody::new(self) + } + + /// Parse `application/x-www-form-urlencoded` encoded body. + /// Return `UrlEncoded` future. It resolves to a `HashMap` which + /// contains decoded parameters. + /// + /// Returns error: + /// + /// * content type is not `application/x-www-form-urlencoded` + /// * transfer encoding is `chunked`. + /// * content-length is greater than 256k + /// + /// ## Server example + /// + /// ```rust + /// # extern crate actix_web; + /// # extern crate futures; + /// use actix_web::*; + /// use futures::future::{Future, ok}; + /// + /// fn index(mut req: HttpRequest) -> Box> { + /// req.urlencoded() // <- get UrlEncoded future + /// .from_err() + /// .and_then(|params| { // <- url encoded parameters + /// println!("==== BODY ==== {:?}", params); + /// ok(httpcodes::HttpOk.into()) + /// }) + /// .responder() + /// } + /// # fn main() {} + /// ``` + fn urlencoded(self) -> UrlEncoded + where Self: Stream + Sized + { + UrlEncoded::new(self) + } + + /// Parse `application/json` encoded body. + /// Return `JsonBody` future. It resolves to a `T` value. + /// + /// Returns error: + /// + /// * content type is not `application/json` + /// * content length is greater than 256k + /// + /// ## Server example + /// + /// ```rust + /// # extern crate actix_web; + /// # extern crate futures; + /// # #[macro_use] extern crate serde_derive; + /// use actix_web::*; + /// use futures::future::{Future, ok}; + /// + /// #[derive(Deserialize, Debug)] + /// struct MyObj { + /// name: String, + /// } + /// + /// fn index(mut req: HttpRequest) -> Box> { + /// req.json() // <- get JsonBody future + /// .from_err() + /// .and_then(|val: MyObj| { // <- deserialized value + /// println!("==== BODY ==== {:?}", val); + /// Ok(httpcodes::HttpOk.into()) + /// }).responder() + /// } + /// # fn main() {} + /// ``` + fn json(self) -> JsonBody + where Self: Stream + Sized + { + JsonBody::new(self) + } + + /// Return stream to http payload processes as multipart. + /// + /// Content-type: multipart/form-data; + /// + /// ## Server example + /// + /// ```rust + /// # extern crate actix; + /// # extern crate actix_web; + /// # extern crate env_logger; + /// # extern crate futures; + /// # use std::str; + /// # use actix::*; + /// # use actix_web::*; + /// # use futures::{Future, Stream}; + /// # use futures::future::{ok, result, Either}; + /// fn index(mut req: HttpRequest) -> Box> { + /// req.multipart().from_err() // <- get multipart stream for current request + /// .and_then(|item| match item { // <- iterate over multipart items + /// multipart::MultipartItem::Field(field) => { + /// // Field in turn is stream of *Bytes* object + /// Either::A(field.from_err() + /// .map(|c| println!("-- CHUNK: \n{:?}", str::from_utf8(&c))) + /// .finish()) + /// }, + /// multipart::MultipartItem::Nested(mp) => { + /// // Or item could be nested Multipart stream + /// Either::B(ok(())) + /// } + /// }) + /// .finish() // <- Stream::finish() combinator from actix + /// .map(|_| httpcodes::HTTPOk.into()) + /// .responder() + /// } + /// # fn main() {} + /// ``` + fn multipart(self) -> Multipart + where Self: Stream + Sized + { + let boundary = Multipart::boundary(self.headers()); + Multipart::new(boundary, self) + } +} + +/// Future that resolves to a complete http message body. +pub struct MessageBody { + limit: usize, + req: Option, + fut: Option>>, +} + +impl MessageBody { + + /// Create `RequestBody` for request. + pub fn new(req: T) -> MessageBody { + MessageBody { + limit: 262_144, + req: Some(req), + fut: None, + } + } + + /// Change max size of payload. By default max size is 256Kb + pub fn limit(mut self, limit: usize) -> Self { + self.limit = limit; + self + } +} + +impl Future for MessageBody + where T: HttpMessage + Stream + 'static +{ + type Item = Bytes; + type Error = PayloadError; + + fn poll(&mut self) -> Poll { + if let Some(req) = self.req.take() { + if let Some(len) = req.headers().get(header::CONTENT_LENGTH) { + if let Ok(s) = len.to_str() { + if let Ok(len) = s.parse::() { + if len > self.limit { + return Err(PayloadError::Overflow); + } + } else { + return Err(PayloadError::UnknownLength); + } + } else { + return Err(PayloadError::UnknownLength); + } + } + + // future + let limit = self.limit; + self.fut = Some(Box::new( + req.from_err() + .fold(BytesMut::new(), move |mut body, chunk| { + if (body.len() + chunk.len()) > limit { + Err(PayloadError::Overflow) + } else { + body.extend_from_slice(&chunk); + Ok(body) + } + }) + .map(|body| body.freeze()) + )); + } + + self.fut.as_mut().expect("UrlEncoded could not be used second time").poll() + } +} + +/// Future that resolves to a parsed urlencoded values. +pub struct UrlEncoded { + req: Option, + limit: usize, + fut: Option, Error=UrlencodedError>>>, +} + +impl UrlEncoded { + pub fn new(req: T) -> UrlEncoded { + UrlEncoded { + req: Some(req), + limit: 262_144, + fut: None, + } + } + + /// Change max size of payload. By default max size is 256Kb + pub fn limit(mut self, limit: usize) -> Self { + self.limit = limit; + self + } +} + +impl Future for UrlEncoded + where T: HttpMessage + Stream + 'static +{ + type Item = HashMap; + type Error = UrlencodedError; + + fn poll(&mut self) -> Poll { + if let Some(req) = self.req.take() { + if req.chunked().unwrap_or(false) { + return Err(UrlencodedError::Chunked) + } else if let Some(len) = req.headers().get(header::CONTENT_LENGTH) { + if let Ok(s) = len.to_str() { + if let Ok(len) = s.parse::() { + if len > 262_144 { + return Err(UrlencodedError::Overflow); + } + } else { + return Err(UrlencodedError::UnknownLength) + } + } else { + return Err(UrlencodedError::UnknownLength) + } + } + + // check content type + if req.content_type().to_lowercase() != "application/x-www-form-urlencoded" { + return Err(UrlencodedError::ContentType) + } + let encoding = req.encoding().map_err(|_| UrlencodedError::ContentType)?; + + // future + let limit = self.limit; + let fut = req.from_err() + .fold(BytesMut::new(), move |mut body, chunk| { + if (body.len() + chunk.len()) > limit { + Err(UrlencodedError::Overflow) + } else { + body.extend_from_slice(&chunk); + Ok(body) + } + }) + .and_then(move |body| { + let mut m = HashMap::new(); + let parsed = form_urlencoded::parse_with_encoding( + &body, Some(encoding), false).map_err(|_| UrlencodedError::Parse)?; + for (k, v) in parsed { + m.insert(k.into(), v.into()); + } + Ok(m) + }); + self.fut = Some(Box::new(fut)); + } + + self.fut.as_mut().expect("UrlEncoded could not be used second time").poll() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use mime; + use encoding::Encoding; + use encoding::all::ISO_8859_2; + use futures::Async; + use http::{Method, Version, Uri}; + use httprequest::HttpRequest; + use std::str::FromStr; + use std::iter::FromIterator; + use test::TestRequest; + + #[test] + fn test_content_type() { + let req = TestRequest::with_header("content-type", "text/plain").finish(); + assert_eq!(req.content_type(), "text/plain"); + let req = TestRequest::with_header( + "content-type", "application/json; charset=utf=8").finish(); + assert_eq!(req.content_type(), "application/json"); + let req = HttpRequest::default(); + assert_eq!(req.content_type(), ""); + } + + #[test] + fn test_mime_type() { + let req = TestRequest::with_header("content-type", "application/json").finish(); + assert_eq!(req.mime_type().unwrap(), Some(mime::APPLICATION_JSON)); + let req = HttpRequest::default(); + assert_eq!(req.mime_type().unwrap(), None); + let req = TestRequest::with_header( + "content-type", "application/json; charset=utf-8").finish(); + let mt = req.mime_type().unwrap().unwrap(); + assert_eq!(mt.get_param(mime::CHARSET), Some(mime::UTF_8)); + assert_eq!(mt.type_(), mime::APPLICATION); + assert_eq!(mt.subtype(), mime::JSON); + } + + #[test] + fn test_mime_type_error() { + let req = TestRequest::with_header( + "content-type", "applicationadfadsfasdflknadsfklnadsfjson").finish(); + assert_eq!(Err(ContentTypeError::ParseError), req.mime_type()); + } + + #[test] + fn test_encoding() { + let req = HttpRequest::default(); + assert_eq!(UTF_8.name(), req.encoding().unwrap().name()); + + let req = TestRequest::with_header( + "content-type", "application/json").finish(); + assert_eq!(UTF_8.name(), req.encoding().unwrap().name()); + + let req = TestRequest::with_header( + "content-type", "application/json; charset=ISO-8859-2").finish(); + assert_eq!(ISO_8859_2.name(), req.encoding().unwrap().name()); + } + + #[test] + fn test_encoding_error() { + let req = TestRequest::with_header( + "content-type", "applicatjson").finish(); + assert_eq!(Some(ContentTypeError::ParseError), req.encoding().err()); + + let req = TestRequest::with_header( + "content-type", "application/json; charset=kkkttktk").finish(); + assert_eq!(Some(ContentTypeError::UnknownEncoding), req.encoding().err()); + } + + #[test] + fn test_no_request_range_header() { + let req = HttpRequest::default(); + let ranges = req.range(100).unwrap(); + assert!(ranges.is_empty()); + } + + #[test] + fn test_request_range_header() { + let req = TestRequest::with_header(header::RANGE, "bytes=0-4").finish(); + let ranges = req.range(100).unwrap(); + assert_eq!(ranges.len(), 1); + assert_eq!(ranges[0].start, 0); + assert_eq!(ranges[0].length, 5); + } + + #[test] + fn test_chunked() { + let req = HttpRequest::default(); + assert!(!req.chunked().unwrap()); + + let req = TestRequest::with_header(header::TRANSFER_ENCODING, "chunked").finish(); + assert!(req.chunked().unwrap()); + + let mut headers = HeaderMap::new(); + let s = unsafe{str::from_utf8_unchecked(b"some va\xadscc\xacas0xsdasdlue".as_ref())}; + + headers.insert(header::TRANSFER_ENCODING, + header::HeaderValue::from_str(s).unwrap()); + let req = HttpRequest::new( + Method::GET, Uri::from_str("/").unwrap(), + Version::HTTP_11, headers, None); + assert!(req.chunked().is_err()); + } + + impl PartialEq for UrlencodedError { + fn eq(&self, other: &UrlencodedError) -> bool { + match *self { + UrlencodedError::Chunked => match *other { + UrlencodedError::Chunked => true, + _ => false, + }, + UrlencodedError::Overflow => match *other { + UrlencodedError::Overflow => true, + _ => false, + }, + UrlencodedError::UnknownLength => match *other { + UrlencodedError::UnknownLength => true, + _ => false, + }, + UrlencodedError::ContentType => match *other { + UrlencodedError::ContentType => true, + _ => false, + }, + _ => false, + } + } + } + + #[test] + fn test_urlencoded_error() { + let req = TestRequest::with_header(header::TRANSFER_ENCODING, "chunked").finish(); + assert_eq!(req.urlencoded().poll().err().unwrap(), UrlencodedError::Chunked); + + let req = TestRequest::with_header( + header::CONTENT_TYPE, "application/x-www-form-urlencoded") + .header(header::CONTENT_LENGTH, "xxxx") + .finish(); + assert_eq!(req.urlencoded().poll().err().unwrap(), UrlencodedError::UnknownLength); + + let req = TestRequest::with_header( + header::CONTENT_TYPE, "application/x-www-form-urlencoded") + .header(header::CONTENT_LENGTH, "1000000") + .finish(); + assert_eq!(req.urlencoded().poll().err().unwrap(), UrlencodedError::Overflow); + + let req = TestRequest::with_header( + header::CONTENT_TYPE, "text/plain") + .header(header::CONTENT_LENGTH, "10") + .finish(); + assert_eq!(req.urlencoded().poll().err().unwrap(), UrlencodedError::ContentType); + } + + #[test] + fn test_urlencoded() { + let mut req = TestRequest::with_header( + header::CONTENT_TYPE, "application/x-www-form-urlencoded") + .header(header::CONTENT_LENGTH, "11") + .finish(); + req.payload_mut().unread_data(Bytes::from_static(b"hello=world")); + + let result = req.urlencoded().poll().ok().unwrap(); + assert_eq!(result, Async::Ready( + HashMap::from_iter(vec![("hello".to_owned(), "world".to_owned())]))); + + let mut req = TestRequest::with_header( + header::CONTENT_TYPE, "application/x-www-form-urlencoded; charset=utf-8") + .header(header::CONTENT_LENGTH, "11") + .finish(); + req.payload_mut().unread_data(Bytes::from_static(b"hello=world")); + + let result = req.urlencoded().poll().ok().unwrap(); + assert_eq!(result, Async::Ready( + HashMap::from_iter(vec![("hello".to_owned(), "world".to_owned())]))); + } + + #[test] + fn test_message_body() { + let req = TestRequest::with_header(header::CONTENT_LENGTH, "xxxx").finish(); + match req.body().poll().err().unwrap() { + PayloadError::UnknownLength => (), + _ => panic!("error"), + } + + let req = TestRequest::with_header(header::CONTENT_LENGTH, "1000000").finish(); + match req.body().poll().err().unwrap() { + PayloadError::Overflow => (), + _ => panic!("error"), + } + + let mut req = HttpRequest::default(); + req.payload_mut().unread_data(Bytes::from_static(b"test")); + match req.body().poll().ok().unwrap() { + Async::Ready(bytes) => assert_eq!(bytes, Bytes::from_static(b"test")), + _ => panic!("error"), + } + + let mut req = HttpRequest::default(); + req.payload_mut().unread_data(Bytes::from_static(b"11111111111111")); + match req.body().limit(5).poll().err().unwrap() { + PayloadError::Overflow => (), + _ => panic!("error"), + } + } +} diff --git a/src/httprequest.rs b/src/httprequest.rs index aa0eb9f0e..688bea7a4 100644 --- a/src/httprequest.rs +++ b/src/httprequest.rs @@ -1,29 +1,25 @@ //! HTTP Request message related code. -use std::{str, fmt, mem}; +use std::{io, cmp, str, fmt, mem}; use std::rc::Rc; use std::net::SocketAddr; -use std::collections::HashMap; -use bytes::{Bytes, BytesMut}; +use bytes::Bytes; use cookie::Cookie; -use futures::{Async, Future, Stream, Poll}; -use http_range::HttpRange; -use serde::de::DeserializeOwned; -use mime::Mime; +use futures::{Async, Stream, Poll}; +use failure; use url::{Url, form_urlencoded}; use http::{header, Uri, Method, Version, HeaderMap, Extensions}; +use tokio_io::AsyncRead; use info::ConnectionInfo; use param::Params; use router::Router; -use payload::{Payload, ReadAny}; -use json::JsonBody; -use multipart::Multipart; -use helpers::SharedHttpMessage; -use error::{ParseError, UrlGenerationError, - CookieParseError, HttpRangeError, PayloadError, UrlencodedError}; +use payload::Payload; +use httpmessage::HttpMessage; +use helpers::SharedHttpInnerMessage; +use error::{UrlGenerationError, CookieParseError, PayloadError}; -pub struct HttpMessage { +pub struct HttpInnerMessage { pub version: Version, pub method: Method, pub uri: Uri, @@ -38,10 +34,10 @@ pub struct HttpMessage { pub info: Option>, } -impl Default for HttpMessage { +impl Default for HttpInnerMessage { - fn default() -> HttpMessage { - HttpMessage { + fn default() -> HttpInnerMessage { + HttpInnerMessage { method: Method::GET, uri: Uri::default(), version: Version::HTTP_11, @@ -58,7 +54,7 @@ impl Default for HttpMessage { } } -impl HttpMessage { +impl HttpInnerMessage { /// Checks if a connection should be kept alive. #[inline] @@ -94,7 +90,7 @@ impl HttpMessage { } /// An HTTP Request -pub struct HttpRequest(SharedHttpMessage, Option>, Option); +pub struct HttpRequest(SharedHttpInnerMessage, Option>, Option); impl HttpRequest<()> { /// Construct a new Request. @@ -103,17 +99,17 @@ impl HttpRequest<()> { version: Version, headers: HeaderMap, payload: Option) -> HttpRequest { HttpRequest( - SharedHttpMessage::from_message(HttpMessage { - method: method, - uri: uri, - version: version, - headers: headers, + SharedHttpInnerMessage::from_message(HttpInnerMessage { + method, + uri, + version, + headers, + payload, params: Params::new(), query: Params::new(), query_loaded: false, cookies: None, addr: None, - payload: payload, extensions: Extensions::new(), info: None, }), @@ -124,7 +120,7 @@ impl HttpRequest<()> { #[inline(always)] #[cfg_attr(feature = "cargo-clippy", allow(inline_always))] - pub(crate) fn from_message(msg: SharedHttpMessage) -> HttpRequest { + pub(crate) fn from_message(msg: SharedHttpInnerMessage) -> HttpRequest { HttpRequest(msg, None, None) } @@ -141,6 +137,14 @@ impl HttpRequest<()> { } } + +impl HttpMessage for HttpRequest { + #[inline] + fn headers(&self) -> &HeaderMap { + &self.as_ref().headers + } +} + impl HttpRequest { #[inline] @@ -151,26 +155,26 @@ impl HttpRequest { #[inline] /// Construct new http request without state. - pub(crate) fn clone_without_state(&self) -> HttpRequest { + pub(crate) fn without_state(&self) -> HttpRequest { HttpRequest(self.0.clone(), None, None) } - // get mutable reference for inner message - // mutable reference should not be returned as result for request's method + /// get mutable reference for inner message + /// mutable reference should not be returned as result for request's method #[inline(always)] #[cfg_attr(feature = "cargo-clippy", allow(mut_from_ref, inline_always))] - pub(crate) fn as_mut(&self) -> &mut HttpMessage { + pub(crate) fn as_mut(&self) -> &mut HttpInnerMessage { self.0.get_mut() } #[inline(always)] #[cfg_attr(feature = "cargo-clippy", allow(mut_from_ref, inline_always))] - fn as_ref(&self) -> &HttpMessage { + fn as_ref(&self) -> &HttpInnerMessage { self.0.get_ref() } #[inline] - pub(crate) fn get_inner(&mut self) -> &mut HttpMessage { + pub(crate) fn get_inner(&mut self) -> &mut HttpInnerMessage { self.as_mut() } @@ -214,12 +218,6 @@ impl HttpRequest { self.as_ref().version } - /// Read the Request Headers. - #[inline] - pub fn headers(&self) -> &HeaderMap { - &self.as_ref().headers - } - #[doc(hidden)] #[inline] pub fn headers_mut(&mut self) -> &mut HeaderMap { @@ -251,14 +249,14 @@ impl HttpRequest { /// # /// fn index(req: HttpRequest) -> HttpResponse { /// let url = req.url_for("foo", &["1", "2", "3"]); // <- generate url for "foo" resource - /// HTTPOk.into() + /// HttpOk.into() /// } /// /// fn main() { /// let app = Application::new() /// .resource("/test/{one}/{two}/{three}", |r| { /// r.name("foo"); // <- set resource name, then it could be used in `url_for` - /// r.method(Method::GET).f(|_| httpcodes::HTTPOk); + /// r.method(Method::GET).f(|_| httpcodes::HttpOk); /// }) /// .finish(); /// } @@ -367,7 +365,7 @@ impl HttpRequest { /// Get mutable reference to request's Params. #[inline] - pub(crate) fn match_info_mut(&mut self) -> &mut Params { + pub fn match_info_mut(&mut self) -> &mut Params { unsafe{ mem::transmute(&mut self.as_mut().params) } } @@ -376,30 +374,6 @@ impl HttpRequest { self.as_ref().keep_alive() } - /// Read the request content type. If request does not contain - /// *Content-Type* header, empty str get returned. - pub fn content_type(&self) -> &str { - if let Some(content_type) = self.headers().get(header::CONTENT_TYPE) { - if let Ok(content_type) = content_type.to_str() { - return content_type.split(';').next().unwrap().trim() - } - } - "" - } - - /// Convert the request content type to a known mime type. - pub fn mime_type(&self) -> Option { - if let Some(content_type) = self.headers().get(header::CONTENT_TYPE) { - if let Ok(content_type) = content_type.to_str() { - return match content_type.parse() { - Ok(mt) => Some(mt), - Err(_) => None - }; - } - } - None - } - /// Check if request requires connection upgrade pub(crate) fn upgrade(&self) -> bool { if let Some(conn) = self.as_ref().headers.get(header::CONNECTION) { @@ -410,33 +384,8 @@ impl HttpRequest { self.as_ref().method == Method::CONNECT } - /// Check if request has chunked transfer encoding - pub fn chunked(&self) -> Result { - if let Some(encodings) = self.headers().get(header::TRANSFER_ENCODING) { - if let Ok(s) = encodings.to_str() { - Ok(s.to_lowercase().contains("chunked")) - } else { - Err(ParseError::Header) - } - } else { - Ok(false) - } - } - - /// Parses Range HTTP header string as per RFC 2616. - /// `size` is full size of response (file). - pub fn range(&self, size: u64) -> Result, HttpRangeError> { - if let Some(range) = self.headers().get(header::RANGE) { - HttpRange::parse(unsafe{str::from_utf8_unchecked(range.as_bytes())}, size) - .map_err(|e| e.into()) - } else { - Ok(Vec::new()) - } - } - - /// Returns reference to the associated http payload. - #[inline] - pub fn payload(&self) -> &Payload { + #[cfg(test)] + pub(crate) fn payload(&self) -> &Payload { let msg = self.as_mut(); if msg.payload.is_none() { msg.payload = Some(Payload::empty()); @@ -444,157 +393,21 @@ impl HttpRequest { msg.payload.as_ref().unwrap() } - /// Returns mutable reference to the associated http payload. - #[inline] - pub fn payload_mut(&mut self) -> &mut Payload { + #[cfg(test)] + pub(crate) fn payload_mut(&mut self) -> &mut Payload { let msg = self.as_mut(); if msg.payload.is_none() { msg.payload = Some(Payload::empty()); } msg.payload.as_mut().unwrap() } - - /// Load request body. - /// - /// By default only 256Kb payload reads to a memory, then `BAD REQUEST` - /// http response get returns to a peer. Use `RequestBody::limit()` - /// method to change upper limit. - /// - /// ```rust - /// # extern crate bytes; - /// # extern crate actix_web; - /// # extern crate futures; - /// # #[macro_use] extern crate serde_derive; - /// use actix_web::*; - /// use bytes::Bytes; - /// use futures::future::Future; - /// - /// fn index(mut req: HttpRequest) -> Box> { - /// req.body() // <- get Body future - /// .limit(1024) // <- change max size of the body to a 1kb - /// .from_err() - /// .and_then(|bytes: Bytes| { // <- complete body - /// println!("==== BODY ==== {:?}", bytes); - /// Ok(httpcodes::HTTPOk.into()) - /// }).responder() - /// } - /// # fn main() {} - /// ``` - pub fn body(&self) -> RequestBody { - RequestBody::from_request(self) - } - - /// Return stream to http payload processes as multipart. - /// - /// Content-type: multipart/form-data; - /// - /// ```rust - /// # extern crate actix; - /// # extern crate actix_web; - /// # extern crate env_logger; - /// # extern crate futures; - /// # use std::str; - /// # use actix::*; - /// # use actix_web::*; - /// # use futures::{Future, Stream}; - /// # use futures::future::{ok, result, Either}; - /// fn index(mut req: HttpRequest) -> Box> { - /// req.multipart().from_err() // <- get multipart stream for current request - /// .and_then(|item| match item { // <- iterate over multipart items - /// multipart::MultipartItem::Field(field) => { - /// // Field in turn is stream of *Bytes* object - /// Either::A(field.from_err() - /// .map(|c| println!("-- CHUNK: \n{:?}", str::from_utf8(&c))) - /// .finish()) - /// }, - /// multipart::MultipartItem::Nested(mp) => { - /// // Or item could be nested Multipart stream - /// Either::B(ok(())) - /// } - /// }) - /// .finish() // <- Stream::finish() combinator from actix - /// .map(|_| httpcodes::HTTPOk.into()) - /// .responder() - /// } - /// # fn main() {} - /// ``` - pub fn multipart(&mut self) -> Multipart { - Multipart::from_request(self) - } - - /// Parse `application/x-www-form-urlencoded` encoded body. - /// Return `UrlEncoded` future. It resolves to a `HashMap` which - /// contains decoded parameters. - /// - /// Returns error: - /// - /// * content type is not `application/x-www-form-urlencoded` - /// * transfer encoding is `chunked`. - /// * content-length is greater than 256k - /// - /// ```rust - /// # extern crate actix_web; - /// # extern crate futures; - /// use actix_web::*; - /// use futures::future::{Future, ok}; - /// - /// fn index(mut req: HttpRequest) -> Box> { - /// req.urlencoded() // <- get UrlEncoded future - /// .from_err() - /// .and_then(|params| { // <- url encoded parameters - /// println!("==== BODY ==== {:?}", params); - /// ok(httpcodes::HTTPOk.into()) - /// }) - /// .responder() - /// } - /// # fn main() {} - /// ``` - pub fn urlencoded(&self) -> UrlEncoded { - UrlEncoded::from(self.payload().clone(), - self.headers(), - self.chunked().unwrap_or(false)) - } - - /// Parse `application/json` encoded body. - /// Return `JsonBody` future. It resolves to a `T` value. - /// - /// Returns error: - /// - /// * content type is not `application/json` - /// * content length is greater than 256k - /// - /// ```rust - /// # extern crate actix_web; - /// # extern crate futures; - /// # #[macro_use] extern crate serde_derive; - /// use actix_web::*; - /// use futures::future::{Future, ok}; - /// - /// #[derive(Deserialize, Debug)] - /// struct MyObj { - /// name: String, - /// } - /// - /// fn index(mut req: HttpRequest) -> Box> { - /// req.json() // <- get JsonBody future - /// .from_err() - /// .and_then(|val: MyObj| { // <- deserialized value - /// println!("==== BODY ==== {:?}", val); - /// Ok(httpcodes::HTTPOk.into()) - /// }).responder() - /// } - /// # fn main() {} - /// ``` - pub fn json(&self) -> JsonBody { - JsonBody::from_request(self) - } } impl Default for HttpRequest<()> { /// Construct default request fn default() -> HttpRequest { - HttpRequest(SharedHttpMessage::default(), None, None) + HttpRequest(SharedHttpInnerMessage::default(), None, None) } } @@ -604,6 +417,56 @@ impl Clone for HttpRequest { } } +impl Stream for HttpRequest { + type Item = Bytes; + type Error = PayloadError; + + fn poll(&mut self) -> Poll, PayloadError> { + let msg = self.as_mut(); + if msg.payload.is_none() { + Ok(Async::Ready(None)) + } else { + msg.payload.as_mut().unwrap().poll() + } + } +} + +impl io::Read for HttpRequest { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + if self.as_mut().payload.is_some() { + match self.as_mut().payload.as_mut().unwrap().poll() { + Ok(Async::Ready(Some(mut b))) => { + let i = cmp::min(b.len(), buf.len()); + buf.copy_from_slice(&b.split_to(i)[..i]); + + if !b.is_empty() { + self.as_mut().payload.as_mut().unwrap().unread_data(b); + } + + if i < buf.len() { + match self.read(&mut buf[i..]) { + Ok(n) => Ok(i + n), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Ok(i), + Err(e) => Err(e), + } + } else { + Ok(i) + } + } + Ok(Async::Ready(None)) => Ok(0), + Ok(Async::NotReady) => + Err(io::Error::new(io::ErrorKind::WouldBlock, "Not ready")), + Err(e) => + Err(io::Error::new(io::ErrorKind::Other, failure::Error::from(e).compat())), + } + } else { + Ok(0) + } + } +} + +impl AsyncRead for HttpRequest {} + impl fmt::Debug for HttpRequest { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let res = write!(f, "\nHttpRequest {:?} {}:{}\n", @@ -627,158 +490,10 @@ impl fmt::Debug for HttpRequest { } } -/// Future that resolves to a parsed urlencoded values. -pub struct UrlEncoded { - pl: Payload, - body: BytesMut, - error: Option, -} - -impl UrlEncoded { - pub fn from(pl: Payload, headers: &HeaderMap, chunked: bool) -> UrlEncoded { - let mut encoded = UrlEncoded { - pl: pl, - body: BytesMut::new(), - error: None - }; - - if chunked { - encoded.error = Some(UrlencodedError::Chunked); - } else if let Some(len) = headers.get(header::CONTENT_LENGTH) { - if let Ok(s) = len.to_str() { - if let Ok(len) = s.parse::() { - if len > 262_144 { - encoded.error = Some(UrlencodedError::Overflow); - } - } else { - encoded.error = Some(UrlencodedError::UnknownLength); - } - } else { - encoded.error = Some(UrlencodedError::UnknownLength); - } - } - - // check content type - if encoded.error.is_none() { - if let Some(content_type) = headers.get(header::CONTENT_TYPE) { - if let Ok(content_type) = content_type.to_str() { - if content_type.to_lowercase() == "application/x-www-form-urlencoded" { - return encoded - } - } - } - encoded.error = Some(UrlencodedError::ContentType); - return encoded - } - - encoded - } -} - -impl Future for UrlEncoded { - type Item = HashMap; - type Error = UrlencodedError; - - fn poll(&mut self) -> Poll { - if let Some(err) = self.error.take() { - return Err(err) - } - - loop { - return match self.pl.poll() { - Ok(Async::NotReady) => Ok(Async::NotReady), - Ok(Async::Ready(None)) => { - let mut m = HashMap::new(); - for (k, v) in form_urlencoded::parse(&self.body) { - m.insert(k.into(), v.into()); - } - Ok(Async::Ready(m)) - }, - Ok(Async::Ready(Some(item))) => { - self.body.extend_from_slice(&item); - continue - }, - Err(err) => Err(err.into()), - } - } - } -} - -/// Future that resolves to a complete request body. -pub struct RequestBody { - pl: ReadAny, - body: BytesMut, - limit: usize, - req: Option>, -} - -impl RequestBody { - - /// Create `RequestBody` for request. - pub fn from_request(req: &HttpRequest) -> RequestBody { - let pl = req.payload().readany(); - RequestBody { - pl: pl, - body: BytesMut::new(), - limit: 262_144, - req: Some(req.clone_without_state()) - } - } - - /// Change max size of payload. By default max size is 256Kb - pub fn limit(mut self, limit: usize) -> Self { - self.limit = limit; - self - } -} - -impl Future for RequestBody { - type Item = Bytes; - type Error = PayloadError; - - fn poll(&mut self) -> Poll { - if let Some(req) = self.req.take() { - if let Some(len) = req.headers().get(header::CONTENT_LENGTH) { - if let Ok(s) = len.to_str() { - if let Ok(len) = s.parse::() { - if len > self.limit { - return Err(PayloadError::Overflow); - } - } else { - return Err(PayloadError::UnknownLength); - } - } else { - return Err(PayloadError::UnknownLength); - } - } - } - - loop { - return match self.pl.poll() { - Ok(Async::NotReady) => Ok(Async::NotReady), - Ok(Async::Ready(None)) => { - Ok(Async::Ready(self.body.take().freeze())) - }, - Ok(Async::Ready(Some(chunk))) => { - if (self.body.len() + chunk.len()) > self.limit { - Err(PayloadError::Overflow) - } else { - self.body.extend_from_slice(&chunk); - continue - } - }, - Err(err) => Err(err), - } - } - } -} - #[cfg(test)] mod tests { use super::*; - use mime; use http::{Uri, HttpTryFrom}; - use std::str::FromStr; use router::Pattern; use resource::Resource; use test::TestRequest; @@ -791,31 +506,6 @@ mod tests { assert!(dbg.contains("HttpRequest")); } - #[test] - fn test_content_type() { - let req = TestRequest::with_header("content-type", "text/plain").finish(); - assert_eq!(req.content_type(), "text/plain"); - let req = TestRequest::with_header( - "content-type", "application/json; charset=utf=8").finish(); - assert_eq!(req.content_type(), "application/json"); - let req = HttpRequest::default(); - assert_eq!(req.content_type(), ""); - } - - #[test] - fn test_mime_type() { - let req = TestRequest::with_header("content-type", "application/json").finish(); - assert_eq!(req.mime_type(), Some(mime::APPLICATION_JSON)); - let req = HttpRequest::default(); - assert_eq!(req.mime_type(), None); - let req = TestRequest::with_header( - "content-type", "application/json; charset=utf-8").finish(); - let mt = req.mime_type().unwrap(); - assert_eq!(mt.get_param(mime::CHARSET), Some(mime::UTF_8)); - assert_eq!(mt.type_(), mime::APPLICATION); - assert_eq!(mt.subtype(), mime::JSON); - } - #[test] fn test_uri_mut() { let mut req = HttpRequest::default(); @@ -853,22 +543,6 @@ mod tests { assert!(cookie.is_none()); } - #[test] - fn test_no_request_range_header() { - let req = HttpRequest::default(); - let ranges = req.range(100).unwrap(); - assert!(ranges.is_empty()); - } - - #[test] - fn test_request_range_header() { - let req = TestRequest::with_header(header::RANGE, "bytes=0-4").finish(); - let ranges = req.range(100).unwrap(); - assert_eq!(ranges.len(), 1); - assert_eq!(ranges[0].start, 0); - assert_eq!(ranges[0].length, 5); - } - #[test] fn test_request_query() { let req = TestRequest::with_uri("/?id=test").finish(); @@ -891,102 +565,6 @@ mod tests { assert_eq!(req.match_info().get("key"), Some("value")); } - #[test] - fn test_chunked() { - let req = HttpRequest::default(); - assert!(!req.chunked().unwrap()); - - let req = TestRequest::with_header(header::TRANSFER_ENCODING, "chunked").finish(); - assert!(req.chunked().unwrap()); - - let mut headers = HeaderMap::new(); - let s = unsafe{str::from_utf8_unchecked(b"some va\xadscc\xacas0xsdasdlue".as_ref())}; - - headers.insert(header::TRANSFER_ENCODING, - header::HeaderValue::from_str(s).unwrap()); - let req = HttpRequest::new( - Method::GET, Uri::from_str("/").unwrap(), - Version::HTTP_11, headers, None); - assert!(req.chunked().is_err()); - } - - impl PartialEq for UrlencodedError { - fn eq(&self, other: &UrlencodedError) -> bool { - match *self { - UrlencodedError::Chunked => match *other { - UrlencodedError::Chunked => true, - _ => false, - }, - UrlencodedError::Overflow => match *other { - UrlencodedError::Overflow => true, - _ => false, - }, - UrlencodedError::UnknownLength => match *other { - UrlencodedError::UnknownLength => true, - _ => false, - }, - UrlencodedError::ContentType => match *other { - UrlencodedError::ContentType => true, - _ => false, - }, - _ => false, - } - } - } - - #[test] - fn test_urlencoded_error() { - let req = TestRequest::with_header(header::TRANSFER_ENCODING, "chunked").finish(); - assert_eq!(req.urlencoded().poll().err().unwrap(), UrlencodedError::Chunked); - - let req = TestRequest::with_header( - header::CONTENT_TYPE, "application/x-www-form-urlencoded") - .header(header::CONTENT_LENGTH, "xxxx") - .finish(); - assert_eq!(req.urlencoded().poll().err().unwrap(), UrlencodedError::UnknownLength); - - let req = TestRequest::with_header( - header::CONTENT_TYPE, "application/x-www-form-urlencoded") - .header(header::CONTENT_LENGTH, "1000000") - .finish(); - assert_eq!(req.urlencoded().poll().err().unwrap(), UrlencodedError::Overflow); - - let req = TestRequest::with_header( - header::CONTENT_TYPE, "text/plain") - .header(header::CONTENT_LENGTH, "10") - .finish(); - assert_eq!(req.urlencoded().poll().err().unwrap(), UrlencodedError::ContentType); - } - - #[test] - fn test_request_body() { - let req = TestRequest::with_header(header::CONTENT_LENGTH, "xxxx").finish(); - match req.body().poll().err().unwrap() { - PayloadError::UnknownLength => (), - _ => panic!("error"), - } - - let req = TestRequest::with_header(header::CONTENT_LENGTH, "1000000").finish(); - match req.body().poll().err().unwrap() { - PayloadError::Overflow => (), - _ => panic!("error"), - } - - let mut req = HttpRequest::default(); - req.payload_mut().unread_data(Bytes::from_static(b"test")); - match req.body().poll().ok().unwrap() { - Async::Ready(bytes) => assert_eq!(bytes, Bytes::from_static(b"test")), - _ => panic!("error"), - } - - let mut req = HttpRequest::default(); - req.payload_mut().unread_data(Bytes::from_static(b"11111111111111")); - match req.body().limit(5).poll().err().unwrap() { - PayloadError::Overflow => (), - _ => panic!("error"), - } - } - #[test] fn test_url_for() { let req = TestRequest::with_header(header::HOST, "www.rust-lang.org") diff --git a/src/httpresponse.rs b/src/httpresponse.rs index 77b63f125..9af932b12 100644 --- a/src/httpresponse.rs +++ b/src/httpresponse.rs @@ -252,7 +252,7 @@ impl HttpResponseBuilder { /// use http::header; /// /// fn index(req: HttpRequest) -> Result { - /// Ok(HTTPOk.build() + /// Ok(HttpOk.build() /// .header("X-TEST", "value") /// .header(header::CONTENT_TYPE, "application/json") /// .finish()?) @@ -372,7 +372,7 @@ impl HttpResponseBuilder { /// use actix_web::headers::Cookie; /// /// fn index(req: HttpRequest) -> Result { - /// Ok(HTTPOk.build() + /// Ok(HttpOk.build() /// .cookie( /// Cookie::build("name", "value") /// .domain("www.rust-lang.org") @@ -659,11 +659,11 @@ impl InnerHttpResponse { #[inline] fn new(status: StatusCode, body: Body) -> InnerHttpResponse { InnerHttpResponse { + status, + body, version: None, headers: HeaderMap::with_capacity(16), - status: status, reason: None, - body: body, chunked: None, encoding: None, connection_type: None, @@ -753,7 +753,7 @@ mod tests { Method::GET, Uri::from_str("/").unwrap(), Version::HTTP_11, headers, None); let cookies = req.cookies().unwrap(); - let resp = httpcodes::HTTPOk + let resp = httpcodes::HttpOk .build() .cookie(headers::Cookie::build("name", "value") .domain("www.rust-lang.org") diff --git a/src/info.rs b/src/info.rs index 92ffa4d6c..6177cd021 100644 --- a/src/info.rs +++ b/src/info.rs @@ -1,5 +1,6 @@ use std::str::FromStr; use http::header::{self, HeaderName}; +use httpmessage::HttpMessage; use httprequest::HttpRequest; const X_FORWARDED_FOR: &str = "X-FORWARDED-FOR"; @@ -110,8 +111,8 @@ impl<'a> ConnectionInfo<'a> { ConnectionInfo { scheme: scheme.unwrap_or("http"), host: host.unwrap_or("localhost"), - remote: remote, - peer: peer, + remote, + peer, } } diff --git a/src/json.rs b/src/json.rs index 8bcda5c90..a41125b41 100644 --- a/src/json.rs +++ b/src/json.rs @@ -1,4 +1,4 @@ -use bytes::BytesMut; +use bytes::{Bytes, BytesMut}; use futures::{Poll, Future, Stream}; use http::header::CONTENT_LENGTH; @@ -6,8 +6,9 @@ use serde_json; use serde::Serialize; use serde::de::DeserializeOwned; -use error::{Error, JsonPayloadError}; +use error::{Error, JsonPayloadError, PayloadError}; use handler::Responder; +use httpmessage::HttpMessage; use httprequest::HttpRequest; use httpresponse::HttpResponse; @@ -54,6 +55,9 @@ impl Responder for Json { /// * content type is not `application/json` /// * content length is greater than 256k /// +/// +/// # Server example +/// /// ```rust /// # extern crate actix_web; /// # extern crate futures; @@ -71,25 +75,25 @@ impl Responder for Json { /// .from_err() /// .and_then(|val: MyObj| { // <- deserialized value /// println!("==== BODY ==== {:?}", val); -/// Ok(httpcodes::HTTPOk.into()) +/// Ok(httpcodes::HttpOk.into()) /// }).responder() /// } /// # fn main() {} /// ``` -pub struct JsonBody{ +pub struct JsonBody{ limit: usize, ct: &'static str, - req: Option>, - fut: Option>>, + req: Option, + fut: Option>>, } -impl JsonBody { +impl JsonBody { /// Create `JsonBody` for request. - pub fn from_request(req: &HttpRequest) -> Self { + pub fn new(req: T) -> Self { JsonBody{ limit: 262_144, - req: Some(req.clone()), + req: Some(req), fut: None, ct: "application/json", } @@ -111,11 +115,13 @@ impl JsonBody { } } -impl Future for JsonBody { - type Item = T; +impl Future for JsonBody + where T: HttpMessage + Stream + 'static +{ + type Item = U; type Error = JsonPayloadError; - fn poll(&mut self) -> Poll { + fn poll(&mut self) -> Poll { if let Some(req) = self.req.take() { if let Some(len) = req.headers().get(CONTENT_LENGTH) { if let Ok(s) = len.to_str() { @@ -134,8 +140,7 @@ impl Future for JsonBody { } let limit = self.limit; - let fut = req.payload().readany() - .from_err() + let fut = req.from_err() .fold(BytesMut::new(), move |mut body, chunk| { if (body.len() + chunk.len()) > limit { Err(JsonPayloadError::Overflow) @@ -144,7 +149,7 @@ impl Future for JsonBody { Ok(body) } }) - .and_then(|body| Ok(serde_json::from_slice::(&body)?)); + .and_then(|body| Ok(serde_json::from_slice::(&body)?)); self.fut = Some(Box::new(fut)); } @@ -189,27 +194,31 @@ mod tests { #[test] fn test_json_body() { - let mut req = HttpRequest::default(); + let req = HttpRequest::default(); let mut json = req.json::(); assert_eq!(json.poll().err().unwrap(), JsonPayloadError::ContentType); - let mut json = req.json::().content_type("text/json"); + let mut req = HttpRequest::default(); req.headers_mut().insert(header::CONTENT_TYPE, header::HeaderValue::from_static("application/json")); + let mut json = req.json::().content_type("text/json"); assert_eq!(json.poll().err().unwrap(), JsonPayloadError::ContentType); - let mut json = req.json::().limit(100); + let mut req = HttpRequest::default(); req.headers_mut().insert(header::CONTENT_TYPE, header::HeaderValue::from_static("application/json")); req.headers_mut().insert(header::CONTENT_LENGTH, header::HeaderValue::from_static("10000")); + let mut json = req.json::().limit(100); assert_eq!(json.poll().err().unwrap(), JsonPayloadError::Overflow); + let mut req = HttpRequest::default(); + req.headers_mut().insert(header::CONTENT_TYPE, + header::HeaderValue::from_static("application/json")); req.headers_mut().insert(header::CONTENT_LENGTH, header::HeaderValue::from_static("16")); req.payload_mut().unread_data(Bytes::from_static(b"{\"name\": \"test\"}")); let mut json = req.json::(); assert_eq!(json.poll().ok().unwrap(), Async::Ready(MyObject{name: "test".to_owned()})); } - } diff --git a/src/lib.rs b/src/lib.rs index c9819aef3..f3decb149 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,4 @@ -//! Actix web is a small, fast, pragmatic, open source rust web framework. +//! Actix web is a small, pragmatic, extremely fast, web framework for Rust. //! //! ```rust //! use actix_web::*; @@ -32,19 +32,20 @@ //! * Supported *HTTP/1.x* and *HTTP/2.0* protocols //! * Streaming and pipelining //! * Keep-alive and slow requests handling -//! * `WebSockets` +//! * `WebSockets` server/client //! * Transparent content compression/decompression (br, gzip, deflate) //! * Configurable request routing -//! * Multipart streams -//! * Middlewares (`Logger`, `Session`, `DefaultHeaders`) //! * Graceful server shutdown -//! * Built on top of [Actix](https://github.com/actix/actix). +//! * Multipart streams +//! * SSL support with openssl or native-tls +//! * Middlewares (`Logger`, `Session`, `CORS`, `DefaultHeaders`) +//! * Built on top of [Actix actor framework](https://github.com/actix/actix). #![cfg_attr(actix_nightly, feature( specialization, // for impl ErrorResponse for std::error::Error ))] #![cfg_attr(feature = "cargo-clippy", allow( - decimal_literal_representation,))] + decimal_literal_representation,suspicious_arithmetic_impl,))] #[macro_use] extern crate log; @@ -77,6 +78,7 @@ extern crate serde; extern crate serde_json; extern crate flate2; extern crate brotli2; +extern crate encoding; extern crate percent_encoding; extern crate smallvec; extern crate num_cpus; @@ -100,16 +102,18 @@ extern crate tokio_openssl; mod application; mod body; mod context; +mod handler; mod helpers; +mod httpmessage; mod httprequest; mod httpresponse; mod info; mod json; mod route; mod router; -mod param; mod resource; -mod handler; +mod param; +mod payload; mod pipeline; pub mod client; @@ -121,12 +125,12 @@ pub mod multipart; pub mod middleware; pub mod pred; pub mod test; -pub mod payload; pub mod server; pub use error::{Error, Result, ResponseError}; pub use body::{Body, Binary}; pub use json::Json; pub use application::Application; +pub use httpmessage::HttpMessage; pub use httprequest::HttpRequest; pub use httpresponse::HttpResponse; pub use handler::{Reply, Responder, NormalizePath, AsyncResponder}; @@ -185,11 +189,12 @@ pub mod dev { //! ``` pub use body::BodyStream; + pub use context::Drain; pub use info::ConnectionInfo; pub use handler::Handler; pub use json::JsonBody; pub use router::{Router, Pattern}; pub use param::{FromParam, Params}; - pub use httprequest::{UrlEncoded, RequestBody}; + pub use httpmessage::{UrlEncoded, MessageBody}; pub use httpresponse::HttpResponseBuilder; } diff --git a/src/middleware/cors.rs b/src/middleware/cors.rs index 748ab1bba..25ae747ce 100644 --- a/src/middleware/cors.rs +++ b/src/middleware/cors.rs @@ -38,8 +38,8 @@ //! .max_age(3600) //! .finish().expect("Can not create CORS middleware") //! .register(r); // <- Register CORS middleware -//! r.method(Method::GET).f(|_| httpcodes::HTTPOk); -//! r.method(Method::HEAD).f(|_| httpcodes::HTTPMethodNotAllowed); +//! r.method(Method::GET).f(|_| httpcodes::HttpOk); +//! r.method(Method::HEAD).f(|_| httpcodes::HttpMethodNotAllowed); //! }) //! .finish(); //! } @@ -55,9 +55,10 @@ use http::header::{self, HeaderName, HeaderValue}; use error::{Result, ResponseError}; use resource::Resource; +use httpmessage::HttpMessage; use httprequest::HttpRequest; use httpresponse::HttpResponse; -use httpcodes::{HTTPOk, HTTPBadRequest}; +use httpcodes::{HttpOk, HttpBadRequest}; use middleware::{Middleware, Response, Started}; /// A set of errors that can occur during processing CORS @@ -109,7 +110,7 @@ pub enum CorsBuilderError { impl ResponseError for CorsError { fn error_response(&self) -> HttpResponse { - HTTPBadRequest.build().body(format!("{}", self)).unwrap() + HttpBadRequest.build().body(format!("{}", self)).unwrap() } } @@ -218,7 +219,7 @@ impl Cors { /// method, but in that case *Cors* middleware wont be able to handle *OPTIONS* /// requests. pub fn register(self, resource: &mut Resource) { - resource.method(Method::OPTIONS).h(HTTPOk); + resource.method(Method::OPTIONS).h(HttpOk); resource.middleware(self); } @@ -306,7 +307,7 @@ impl Middleware for Cors { }; Ok(Started::Response( - HTTPOk.build() + HttpOk.build() .if_some(self.max_age.as_ref(), |max_age, resp| { let _ = resp.header( header::ACCESS_CONTROL_MAX_AGE, format!("{}", max_age).as_str());}) @@ -822,7 +823,7 @@ mod tests { .method(Method::OPTIONS) .finish(); - let resp: HttpResponse = HTTPOk.into(); + let resp: HttpResponse = HttpOk.into(); let resp = cors.response(&mut req, resp).unwrap().response(); assert_eq!( &b"*"[..], @@ -831,7 +832,7 @@ mod tests { &b"Origin"[..], resp.headers().get(header::VARY).unwrap().as_bytes()); - let resp: HttpResponse = HTTPOk.build() + let resp: HttpResponse = HttpOk.build() .header(header::VARY, "Accept") .finish().unwrap(); let resp = cors.response(&mut req, resp).unwrap().response(); @@ -843,7 +844,7 @@ mod tests { .disable_vary_header() .allowed_origin("https://www.example.com") .finish().unwrap(); - let resp: HttpResponse = HTTPOk.into(); + let resp: HttpResponse = HttpOk.into(); let resp = cors.response(&mut req, resp).unwrap().response(); assert_eq!( &b"https://www.example.com"[..], diff --git a/src/middleware/csrf.rs b/src/middleware/csrf.rs new file mode 100644 index 000000000..5385f5d4d --- /dev/null +++ b/src/middleware/csrf.rs @@ -0,0 +1,266 @@ +//! A filter for cross-site request forgery (CSRF). +//! +//! This middleware is stateless and [based on request +//! headers](https://www.owasp.org/index.php/Cross-Site_Request_Forgery_(CSRF)_Prevention_Cheat_Sheet#Verifying_Same_Origin_with_Standard_Headers). +//! +//! By default requests are allowed only if one of these is true: +//! +//! * The request method is safe (`GET`, `HEAD`, `OPTIONS`). It is the +//! applications responsibility to ensure these methods cannot be used to +//! execute unwanted actions. Note that upgrade requests for websockets are +//! also considered safe. +//! * The `Origin` header (added automatically by the browser) matches one +//! of the allowed origins. +//! * There is no `Origin` header but the `Referer` header matches one of +//! the allowed origins. +//! +//! Use [`CsrfFilterBuilder::allow_xhr()`](struct.CsrfFilterBuilder.html#method.allow_xhr) +//! if you want to allow requests with unsafe methods via +//! [CORS](../cors/struct.Cors.html). +//! +//! # Example +//! +//! ``` +//! # extern crate actix_web; +//! # use actix_web::*; +//! +//! use actix_web::middleware::csrf; +//! +//! fn handle_post(_req: HttpRequest) -> &'static str { +//! "This action should only be triggered with requests from the same site" +//! } +//! +//! fn main() { +//! let app = Application::new() +//! .middleware( +//! csrf::CsrfFilter::build() +//! .allowed_origin("https://www.example.com") +//! .finish()) +//! .resource("/", |r| { +//! r.method(Method::GET).f(|_| httpcodes::HttpOk); +//! r.method(Method::POST).f(handle_post); +//! }) +//! .finish(); +//! } +//! ``` +//! +//! In this example the entire application is protected from CSRF. + +use std::borrow::Cow; +use std::collections::HashSet; + +use bytes::Bytes; +use error::{Result, ResponseError}; +use http::{HeaderMap, HttpTryFrom, Uri, header}; +use httprequest::HttpRequest; +use httpresponse::HttpResponse; +use httpmessage::HttpMessage; +use httpcodes::HttpForbidden; +use middleware::{Middleware, Started}; + +/// Potential cross-site request forgery detected. +#[derive(Debug, Fail)] +pub enum CsrfError { + /// The HTTP request header `Origin` was required but not provided. + #[fail(display="Origin header required")] + MissingOrigin, + /// The HTTP request header `Origin` could not be parsed correctly. + #[fail(display="Could not parse Origin header")] + BadOrigin, + /// The cross-site request was denied. + #[fail(display="Cross-site request denied")] + CsrDenied, +} + +impl ResponseError for CsrfError { + fn error_response(&self) -> HttpResponse { + HttpForbidden.build().body(self.to_string()).unwrap() + } +} + +fn uri_origin(uri: &Uri) -> Option { + match (uri.scheme_part(), uri.host(), uri.port()) { + (Some(scheme), Some(host), Some(port)) => { + Some(format!("{}://{}:{}", scheme, host, port)) + } + (Some(scheme), Some(host), None) => { + Some(format!("{}://{}", scheme, host)) + } + _ => None + } +} + +fn origin(headers: &HeaderMap) -> Option, CsrfError>> { + headers.get(header::ORIGIN) + .map(|origin| { + origin + .to_str() + .map_err(|_| CsrfError::BadOrigin) + .map(|o| o.into()) + }) + .or_else(|| { + headers.get(header::REFERER) + .map(|referer| { + Uri::try_from(Bytes::from(referer.as_bytes())) + .ok() + .as_ref() + .and_then(uri_origin) + .ok_or(CsrfError::BadOrigin) + .map(|o| o.into()) + }) + }) +} + +/// A middleware that filters cross-site requests. +pub struct CsrfFilter { + origins: HashSet, + allow_xhr: bool, + allow_missing_origin: bool, +} + +impl CsrfFilter { + /// Start building a `CsrfFilter`. + pub fn build() -> CsrfFilterBuilder { + CsrfFilterBuilder { + cors: CsrfFilter { + origins: HashSet::new(), + allow_xhr: false, + allow_missing_origin: false, + } + } + } + + fn validate(&self, req: &mut HttpRequest) -> Result<(), CsrfError> { + if req.method().is_safe() || (self.allow_xhr && req.headers().contains_key("x-requested-with")) { + Ok(()) + } else if let Some(header) = origin(req.headers()) { + match header { + Ok(ref origin) if self.origins.contains(origin.as_ref()) => Ok(()), + Ok(_) => Err(CsrfError::CsrDenied), + Err(err) => Err(err), + } + } else if self.allow_missing_origin { + Ok(()) + } else { + Err(CsrfError::MissingOrigin) + } + } +} + +impl Middleware for CsrfFilter { + fn start(&self, req: &mut HttpRequest) -> Result { + self.validate(req)?; + Ok(Started::Done) + } +} + +/// Used to build a `CsrfFilter`. +/// +/// To construct a CSRF filter: +/// +/// 1. Call [`CsrfFilter::build`](struct.CsrfFilter.html#method.build) to +/// start building. +/// 2. [Add](struct.CsrfFilterBuilder.html#method.allowed_origin) allowed +/// origins. +/// 3. Call [finish](struct.CsrfFilterBuilder.html#method.finish) to retrieve +/// the constructed filter. +/// +/// # Example +/// +/// ``` +/// use actix_web::middleware::csrf; +/// +/// let csrf = csrf::CsrfFilter::build() +/// .allowed_origin("https://www.example.com") +/// .finish(); +/// ``` +pub struct CsrfFilterBuilder { + cors: CsrfFilter, +} + +impl CsrfFilterBuilder { + /// Add an origin that is allowed to make requests. Will be verified + /// against the `Origin` request header. + pub fn allowed_origin(mut self, origin: &str) -> CsrfFilterBuilder { + self.cors.origins.insert(origin.to_owned()); + self + } + + /// Allow all requests with an `X-Requested-With` header. + /// + /// A cross-site attacker should not be able to send requests with custom + /// headers unless a CORS policy whitelists them. Therefore it should be + /// safe to allow requests with an `X-Requested-With` header (added + /// automatically by many JavaScript libraries). + /// + /// This is disabled by default, because in Safari it is possible to + /// circumvent this using redirects and Flash. + /// + /// Use this method to enable more lax filtering. + pub fn allow_xhr(mut self) -> CsrfFilterBuilder { + self.cors.allow_xhr = true; + self + } + + /// Allow requests if the expected `Origin` header is missing (and + /// there is no `Referer` to fall back on). + /// + /// The filter is conservative by default, but it should be safe to allow + /// missing `Origin` headers because a cross-site attacker cannot prevent + /// the browser from sending `Origin` on unsafe requests. + pub fn allow_missing_origin(mut self) -> CsrfFilterBuilder { + self.cors.allow_missing_origin = true; + self + } + + /// Finishes building the `CsrfFilter` instance. + pub fn finish(self) -> CsrfFilter { + self.cors + } +} + +#[cfg(test)] +mod tests { + use super::*; + use http::Method; + use test::TestRequest; + + #[test] + fn test_safe() { + let csrf = CsrfFilter::build() + .allowed_origin("https://www.example.com") + .finish(); + + let mut req = TestRequest::with_header("Origin", "https://www.w3.org") + .method(Method::HEAD) + .finish(); + + assert!(csrf.start(&mut req).is_ok()); + } + + #[test] + fn test_csrf() { + let csrf = CsrfFilter::build() + .allowed_origin("https://www.example.com") + .finish(); + + let mut req = TestRequest::with_header("Origin", "https://www.w3.org") + .method(Method::POST) + .finish(); + + assert!(csrf.start(&mut req).is_err()); + } + + #[test] + fn test_referer() { + let csrf = CsrfFilter::build() + .allowed_origin("https://www.example.com") + .finish(); + + let mut req = TestRequest::with_header("Referer", "https://www.example.com/some/path?query=param") + .method(Method::POST) + .finish(); + + assert!(csrf.start(&mut req).is_ok()); + } +} diff --git a/src/middleware/defaultheaders.rs b/src/middleware/defaultheaders.rs index 4e4686ca8..0dfd38511 100644 --- a/src/middleware/defaultheaders.rs +++ b/src/middleware/defaultheaders.rs @@ -22,8 +22,8 @@ use middleware::{Response, Middleware}; /// .header("X-Version", "0.2") /// .finish()) /// .resource("/test", |r| { -/// r.method(Method::GET).f(|_| httpcodes::HTTPOk); -/// r.method(Method::HEAD).f(|_| httpcodes::HTTPMethodNotAllowed); +/// r.method(Method::GET).f(|_| httpcodes::HttpOk); +/// r.method(Method::HEAD).f(|_| httpcodes::HttpMethodNotAllowed); /// }) /// .finish(); /// } @@ -95,7 +95,7 @@ impl DefaultHeadersBuilder { /// Finishes building and returns the built `DefaultHeaders` middleware. pub fn finish(&mut self) -> DefaultHeaders { let headers = self.headers.take().expect("cannot reuse middleware builder"); - DefaultHeaders{ ct: self.ct, headers: headers } + DefaultHeaders{ ct: self.ct, headers } } } diff --git a/src/middleware/logger.rs b/src/middleware/logger.rs index 4907b214c..f5f2e270b 100644 --- a/src/middleware/logger.rs +++ b/src/middleware/logger.rs @@ -8,6 +8,7 @@ use time; use regex::Regex; use error::Result; +use httpmessage::HttpMessage; use httprequest::HttpRequest; use httpresponse::HttpResponse; use middleware::{Middleware, Started, Finished}; diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index 4270c477b..cfda04b9e 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -9,6 +9,7 @@ mod logger; mod session; mod defaultheaders; pub mod cors; +pub mod csrf; pub use self::logger::Logger; pub use self::defaultheaders::{DefaultHeaders, DefaultHeadersBuilder}; pub use self::session::{RequestSession, Session, SessionImpl, SessionBackend, SessionStorage, diff --git a/src/middleware/session.rs b/src/middleware/session.rs index ad865669f..e08ce03f2 100644 --- a/src/middleware/session.rs +++ b/src/middleware/session.rs @@ -1,6 +1,3 @@ -#![allow(dead_code, unused_imports, unused_variables)] - -use std::any::Any; use std::rc::Rc; use std::sync::Arc; use std::marker::PhantomData; @@ -49,8 +46,7 @@ impl RequestSession for HttpRequest { return Session(s.0.as_mut()) } } - //Session(&mut DUMMY) - unreachable!() + Session(unsafe{&mut DUMMY}) } } @@ -90,7 +86,7 @@ impl<'a> Session<'a> { } /// Set a `value` from the session. - pub fn set(&'a mut self, key: &str, value: T) -> Result<()> { + pub fn set(&mut self, key: &str, value: T) -> Result<()> { self.0.set(key, serde_json::to_string(&value)?); Ok(()) } @@ -195,15 +191,13 @@ pub trait SessionBackend: Sized + 'static { /// Dummy session impl, does not do anything struct DummySessionImpl; -static DUMMY: DummySessionImpl = DummySessionImpl; +static mut DUMMY: DummySessionImpl = DummySessionImpl; impl SessionImpl for DummySessionImpl { - fn get(&self, key: &str) -> Option<&str> { - None - } - fn set(&mut self, key: &str, value: String) {} - fn remove(&mut self, key: &str) {} + fn get(&self, _: &str) -> Option<&str> { None } + fn set(&mut self, _: &str, _: String) {} + fn remove(&mut self, _: &str) {} fn clear(&mut self) {} fn write(&self, resp: HttpResponse) -> Result { Ok(Response::Done(resp)) @@ -377,8 +371,8 @@ impl SessionBackend for CookieSessionBackend { FutOk( CookieSession { changed: false, - state: state, inner: Rc::clone(&self.0), + state, }) } } diff --git a/src/multipart.rs b/src/multipart.rs index 9da15ba59..6211f6116 100644 --- a/src/multipart.rs +++ b/src/multipart.rs @@ -9,12 +9,11 @@ use httparse; use bytes::Bytes; use http::HttpTryFrom; use http::header::{self, HeaderMap, HeaderName, HeaderValue}; -use futures::{Async, Future, Stream, Poll}; +use futures::{Async, Stream, Poll}; use futures::task::{Task, current as current_task}; use error::{ParseError, PayloadError, MultipartError}; -use payload::Payload; -use httprequest::HttpRequest; +use payload::PayloadHelper; const MAX_HEADERS: usize = 32; @@ -24,27 +23,24 @@ const MAX_HEADERS: usize = 32; /// Stream implementation. /// `MultipartItem::Field` contains multipart field. `MultipartItem::Multipart` /// is used for nested multipart streams. -#[derive(Debug)] -pub struct Multipart { +pub struct Multipart { safety: Safety, error: Option, - inner: Option>>, + inner: Option>>>, } /// -#[derive(Debug)] -pub enum MultipartItem { +pub enum MultipartItem { /// Multipart field - Field(Field), + Field(Field), /// Nested multipart stream - Nested(Multipart), + Nested(Multipart), } -#[derive(Debug)] -enum InnerMultipartItem { +enum InnerMultipartItem { None, - Field(Rc>), - Multipart(Rc>), + Field(Rc>>), + Multipart(Rc>>), } #[derive(PartialEq, Debug)] @@ -59,57 +55,14 @@ enum InnerState { Headers, } -#[derive(Debug)] -struct InnerMultipart { - payload: PayloadRef, +struct InnerMultipart { + payload: PayloadRef, boundary: String, state: InnerState, - item: InnerMultipartItem, + item: InnerMultipartItem, } -impl Multipart { - - /// Create multipart instance for boundary. - pub fn new(boundary: String, payload: Payload) -> Multipart { - Multipart { - error: None, - safety: Safety::new(), - inner: Some(Rc::new(RefCell::new( - InnerMultipart { - payload: PayloadRef::new(payload), - boundary: boundary, - state: InnerState::FirstBoundary, - item: InnerMultipartItem::None, - }))) - } - } - - /// Create multipart instance for request. - pub fn from_request(req: &mut HttpRequest) -> Multipart { - match Multipart::boundary(req.headers()) { - Ok(boundary) => Multipart::new(boundary, req.payload().clone()), - Err(err) => - Multipart { - error: Some(err), - safety: Safety::new(), - inner: None, - } - } - } - - // /// Create multipart instance for client response. - // pub fn from_response(resp: &mut ClientResponse) -> Multipart { - // match Multipart::boundary(resp.headers()) { - // Ok(boundary) => Multipart::new(boundary, resp.payload().clone()), - // Err(err) => - // Multipart { - // error: Some(err), - // safety: Safety::new(), - // inner: None, - // } - // } - // } - +impl Multipart<()> { /// Extract boundary info from headers. pub fn boundary(headers: &HeaderMap) -> Result { if let Some(content_type) = headers.get(header::CONTENT_TYPE) { @@ -132,8 +85,34 @@ impl Multipart { } } -impl Stream for Multipart { - type Item = MultipartItem; +impl Multipart where S: Stream { + + /// Create multipart instance for boundary. + pub fn new(boundary: Result, stream: S) -> Multipart { + match boundary { + Ok(boundary) => Multipart { + error: None, + safety: Safety::new(), + inner: Some(Rc::new(RefCell::new( + InnerMultipart { + boundary, + payload: PayloadRef::new(PayloadHelper::new(stream)), + state: InnerState::FirstBoundary, + item: InnerMultipartItem::None, + }))) + }, + Err(err) => + Multipart { + error: Some(err), + safety: Safety::new(), + inner: None, + } + } + } +} + +impl Stream for Multipart where S: Stream { + type Item = MultipartItem; type Error = MultipartError; fn poll(&mut self) -> Poll, Self::Error> { @@ -147,13 +126,14 @@ impl Stream for Multipart { } } -impl InnerMultipart { +impl InnerMultipart where S: Stream { - fn read_headers(payload: &mut Payload) -> Poll + fn read_headers(payload: &mut PayloadHelper) -> Poll { - match payload.readuntil(b"\r\n\r\n").poll()? { + match payload.readuntil(b"\r\n\r\n")? { Async::NotReady => Ok(Async::NotReady), - Async::Ready(bytes) => { + Async::Ready(None) => Err(MultipartError::Incomplete), + Async::Ready(Some(bytes)) => { let mut hdrs = [httparse::EMPTY_HEADER; MAX_HEADERS]; match httparse::parse_headers(&bytes, &mut hdrs) { Ok(httparse::Status::Complete((_, hdrs))) => { @@ -179,12 +159,14 @@ impl InnerMultipart { } } - fn read_boundary(payload: &mut Payload, boundary: &str) -> Poll + fn read_boundary(payload: &mut PayloadHelper, boundary: &str) + -> Poll { // TODO: need to read epilogue - match payload.readline().poll()? { + match payload.readline()? { Async::NotReady => Ok(Async::NotReady), - Async::Ready(chunk) => { + Async::Ready(None) => Err(MultipartError::Incomplete), + Async::Ready(Some(chunk)) => { if chunk.len() == boundary.len() + 4 && &chunk[..2] == b"--" && &chunk[2..boundary.len()+2] == boundary.as_bytes() @@ -203,39 +185,42 @@ impl InnerMultipart { } } - fn skip_until_boundary(payload: &mut Payload, boundary: &str) -> Poll + fn skip_until_boundary(payload: &mut PayloadHelper, boundary: &str) + -> Poll { let mut eof = false; loop { - if let Async::Ready(chunk) = payload.readline().poll()? { - if chunk.is_empty() { - //ValueError("Could not find starting boundary %r" - //% (self._boundary)) - } - if chunk.len() < boundary.len() { - continue - } - if &chunk[..2] == b"--" && &chunk[2..chunk.len()-2] == boundary.as_bytes() { - break; - } else { - if chunk.len() < boundary.len() + 2{ + match payload.readline()? { + Async::Ready(Some(chunk)) => { + if chunk.is_empty() { + //ValueError("Could not find starting boundary %r" + //% (self._boundary)) + } + if chunk.len() < boundary.len() { continue } - let b: &[u8] = boundary.as_ref(); - if &chunk[..boundary.len()] == b && - &chunk[boundary.len()..boundary.len()+2] == b"--" { - eof = true; - break; + if &chunk[..2] == b"--" && &chunk[2..chunk.len()-2] == boundary.as_bytes() { + break; + } else { + if chunk.len() < boundary.len() + 2{ + continue } - } - } else { - return Ok(Async::NotReady) + let b: &[u8] = boundary.as_ref(); + if &chunk[..boundary.len()] == b && + &chunk[boundary.len()..boundary.len()+2] == b"--" { + eof = true; + break; + } + } + }, + Async::NotReady => return Ok(Async::NotReady), + Async::Ready(None) => return Err(MultipartError::Incomplete), } } Ok(Async::Ready(eof)) } - fn poll(&mut self, safety: &Safety) -> Poll, MultipartError> { + fn poll(&mut self, safety: &Safety) -> Poll>, MultipartError> { if self.state == InnerState::Eof { Ok(Async::Ready(None)) } else { @@ -247,25 +232,18 @@ impl InnerMultipart { let stop = match self.item { InnerMultipartItem::Field(ref mut field) => { match field.borrow_mut().poll(safety)? { - Async::NotReady => { - return Ok(Async::NotReady) - } - Async::Ready(Some(_)) => - continue, - Async::Ready(None) => - true, + Async::NotReady => return Ok(Async::NotReady), + Async::Ready(Some(_)) => continue, + Async::Ready(None) => true, } - } + }, InnerMultipartItem::Multipart(ref mut multipart) => { match multipart.borrow_mut().poll(safety)? { - Async::NotReady => - return Ok(Async::NotReady), - Async::Ready(Some(_)) => - continue, - Async::Ready(None) => - true, + Async::NotReady => return Ok(Async::NotReady), + Async::Ready(Some(_)) => continue, + Async::Ready(None) => true, } - } + }, _ => false, }; if stop { @@ -281,25 +259,22 @@ impl InnerMultipart { match self.state { // read until first boundary InnerState::FirstBoundary => { - if let Async::Ready(eof) = - InnerMultipart::skip_until_boundary(payload, &self.boundary)? - { - if eof { - self.state = InnerState::Eof; - return Ok(Async::Ready(None)); - } else { - self.state = InnerState::Headers; - } - } else { - return Ok(Async::NotReady) + match InnerMultipart::skip_until_boundary(payload, &self.boundary)? { + Async::Ready(eof) => { + if eof { + self.state = InnerState::Eof; + return Ok(Async::Ready(None)); + } else { + self.state = InnerState::Headers; + } + }, + Async::NotReady => return Ok(Async::NotReady), } - } + }, // read boundary InnerState::Boundary => { match InnerMultipart::read_boundary(payload, &self.boundary)? { - Async::NotReady => { - return Ok(Async::NotReady) - } + Async::NotReady => return Ok(Async::NotReady), Async::Ready(eof) => { if eof { self.state = InnerState::Eof; @@ -375,7 +350,7 @@ impl InnerMultipart { } } -impl Drop for InnerMultipart { +impl Drop for InnerMultipart { fn drop(&mut self) { // InnerMultipartItem::Field has to be dropped first because of Safety. self.item = InnerMultipartItem::None; @@ -383,23 +358,18 @@ impl Drop for InnerMultipart { } /// A single field in a multipart stream -pub struct Field { +pub struct Field { ct: mime::Mime, headers: HeaderMap, - inner: Rc>, + inner: Rc>>, safety: Safety, } -impl Field { +impl Field where S: Stream { fn new(safety: Safety, headers: HeaderMap, - ct: mime::Mime, inner: Rc>) -> Self { - Field { - ct: ct, - headers: headers, - inner: inner, - safety: safety, - } + ct: mime::Mime, inner: Rc>>) -> Self { + Field {ct, headers, inner, safety} } pub fn headers(&self) -> &HeaderMap { @@ -411,7 +381,7 @@ impl Field { } } -impl Stream for Field { +impl Stream for Field where S: Stream { type Item = Bytes; type Error = MultipartError; @@ -424,7 +394,7 @@ impl Stream for Field { } } -impl fmt::Debug for Field { +impl fmt::Debug for Field { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let res = write!(f, "\nMultipartField: {}\n", self.ct); let _ = write!(f, " boundary: {}\n", self.inner.borrow().boundary); @@ -441,18 +411,17 @@ impl fmt::Debug for Field { } } -#[derive(Debug)] -struct InnerField { - payload: Option, +struct InnerField { + payload: Option>, boundary: String, eof: bool, length: Option, } -impl InnerField { +impl InnerField where S: Stream { - fn new(payload: PayloadRef, boundary: String, headers: &HeaderMap) - -> Result + fn new(payload: PayloadRef, boundary: String, headers: &HeaderMap) + -> Result, PayloadError> { let len = if let Some(len) = headers.get(header::CONTENT_LENGTH) { if let Ok(s) = len.to_str() { @@ -469,22 +438,23 @@ impl InnerField { }; Ok(InnerField { + boundary, payload: Some(payload), - boundary: boundary, eof: false, length: len }) } /// Reads body part content chunk of the specified size. /// The body part must has `Content-Length` header with proper value. - fn read_len(payload: &mut Payload, size: &mut u64) -> Poll, MultipartError> + fn read_len(payload: &mut PayloadHelper, size: &mut u64) + -> Poll, MultipartError> { if *size == 0 { Ok(Async::Ready(None)) } else { - match payload.readany().poll() { + match payload.readany() { Ok(Async::NotReady) => Ok(Async::NotReady), - Ok(Async::Ready(None)) => Ok(Async::Ready(None)), + Ok(Async::Ready(None)) => Err(MultipartError::Incomplete), Ok(Async::Ready(Some(mut chunk))) => { let len = cmp::min(chunk.len() as u64, *size); *size -= len; @@ -501,23 +471,26 @@ impl InnerField { /// Reads content chunk of body part with unknown length. /// The `Content-Length` header for body part is not necessary. - fn read_stream(payload: &mut Payload, boundary: &str) -> Poll, MultipartError> + fn read_stream(payload: &mut PayloadHelper, boundary: &str) + -> Poll, MultipartError> { - match payload.readuntil(b"\r").poll()? { + match payload.readuntil(b"\r")? { Async::NotReady => Ok(Async::NotReady), - Async::Ready(mut chunk) => { + Async::Ready(None) => Err(MultipartError::Incomplete), + Async::Ready(Some(mut chunk)) => { if chunk.len() == 1 { payload.unread_data(chunk); - match payload.readexactly(boundary.len() + 4).poll()? { + match payload.readexactly(boundary.len() + 4)? { Async::NotReady => Ok(Async::NotReady), - Async::Ready(chunk) => { + Async::Ready(None) => Err(MultipartError::Incomplete), + Async::Ready(Some(chunk)) => { if &chunk[..2] == b"\r\n" && &chunk[2..4] == b"--" && &chunk[4..] == boundary.as_bytes() { - payload.unread_data(chunk); + payload.unread_data(chunk.freeze()); Ok(Async::Ready(None)) } else { - Ok(Async::Ready(Some(chunk))) + Ok(Async::Ready(Some(chunk.freeze()))) } } } @@ -535,24 +508,6 @@ impl InnerField { if self.payload.is_none() { return Ok(Async::Ready(None)) } - if self.eof { - if let Some(payload) = self.payload.as_ref().unwrap().get_mut(s) { - match payload.readline().poll()? { - Async::NotReady => - return Ok(Async::NotReady), - Async::Ready(chunk) => { - assert_eq!( - chunk.as_ref(), b"\r\n", - "reader did not read all the data or it is malformed"); - } - } - } else { - return Ok(Async::NotReady); - } - - self.payload.take(); - return Ok(Async::Ready(None)) - } let result = if let Some(payload) = self.payload.as_ref().unwrap().get_mut(s) { let res = if let Some(ref mut len) = self.length { @@ -566,12 +521,13 @@ impl InnerField { Async::Ready(Some(bytes)) => Async::Ready(Some(bytes)), Async::Ready(None) => { self.eof = true; - match payload.readline().poll()? { + match payload.readline()? { Async::NotReady => Async::NotReady, - Async::Ready(chunk) => { - assert_eq!( - chunk.as_ref(), b"\r\n", - "reader did not read all the data or it is malformed"); + Async::Ready(None) => Async::Ready(None), + Async::Ready(Some(line)) => { + if line.as_ref() != b"\r\n" { + warn!("multipart field did not read all the data or it is malformed"); + } Async::Ready(None) } } @@ -588,25 +544,22 @@ impl InnerField { } } -#[derive(Debug)] -struct PayloadRef { - task: Option, - payload: Rc, +struct PayloadRef { + payload: Rc>, } -impl PayloadRef { - fn new(payload: Payload) -> PayloadRef { +impl PayloadRef where S: Stream { + fn new(payload: PayloadHelper) -> PayloadRef { PayloadRef { - task: None, payload: Rc::new(payload), } } - fn get_mut<'a, 'b>(&'a self, s: &'b Safety) -> Option<&'a mut Payload> + fn get_mut<'a, 'b>(&'a self, s: &'b Safety) -> Option<&'a mut PayloadHelper> where 'a: 'b { if s.current() { - let payload: &mut Payload = unsafe { + let payload: &mut PayloadHelper = unsafe { &mut *(self.payload.as_ref() as *const _ as *mut _)}; Some(payload) } else { @@ -615,10 +568,9 @@ impl PayloadRef { } } -impl Clone for PayloadRef { - fn clone(&self) -> PayloadRef { +impl Clone for PayloadRef { + fn clone(&self) -> PayloadRef { PayloadRef { - task: Some(current_task()), payload: Rc::clone(&self.payload), } } @@ -639,7 +591,7 @@ impl Safety { Safety { task: None, level: Rc::strong_count(&payload), - payload: payload, + payload, } } @@ -655,7 +607,7 @@ impl Clone for Safety { Safety { task: Some(current_task()), level: Rc::strong_count(&payload), - payload: payload, + payload, } } } @@ -733,7 +685,7 @@ mod tests { sender.feed_data(bytes); let mut multipart = Multipart::new( - "abbc761f78ff4d7cb7573b5a23f96ef0".to_owned(), payload); + Ok("abbc761f78ff4d7cb7573b5a23f96ef0".to_owned()), payload); match multipart.poll() { Ok(Async::Ready(Some(item))) => { match item { diff --git a/src/payload.rs b/src/payload.rs index 97e59a488..3cefcf718 100644 --- a/src/payload.rs +++ b/src/payload.rs @@ -1,40 +1,18 @@ //! Payload stream -use std::{fmt, cmp}; +use std::cmp; use std::rc::{Rc, Weak}; use std::cell::RefCell; use std::collections::VecDeque; -use std::ops::{Deref, DerefMut}; use bytes::{Bytes, BytesMut}; -use futures::{Future, Async, Poll, Stream}; -use futures::task::{Task, current as current_task}; +use futures::{Async, Poll, Stream}; -use body::BodyStream; use error::PayloadError; -pub(crate) const DEFAULT_BUFFER_SIZE: usize = 65_536; // max buffer size 64k - -/// Just Bytes object -#[derive(PartialEq, Message)] -pub struct PayloadItem(pub Bytes); - -impl Deref for PayloadItem { - type Target = Bytes; - - fn deref(&self) -> &Bytes { - &self.0 - } -} - -impl DerefMut for PayloadItem { - fn deref_mut(&mut self) -> &mut Bytes { - &mut self.0 - } -} - -impl fmt::Debug for PayloadItem { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - fmt::Debug::fmt(&self.0, f) - } +#[derive(Debug, PartialEq)] +pub(crate) enum PayloadStatus { + Read, + Pause, + Dropped, } /// Buffered stream of bytes chunks @@ -88,68 +66,25 @@ impl Payload { self.inner.borrow().len() == 0 } - /// Get first available chunk of data. - #[inline] - pub fn readany(&self) -> ReadAny { - ReadAny(Rc::clone(&self.inner)) - } - - /// Get exact number of bytes - #[inline] - pub fn readexactly(&self, size: usize) -> ReadExactly { - ReadExactly(Rc::clone(&self.inner), size) - } - - /// Read until `\n` - #[inline] - pub fn readline(&self) -> ReadLine { - ReadLine(Rc::clone(&self.inner)) - } - - /// Read until match line - #[inline] - pub fn readuntil(&self, line: &[u8]) -> ReadUntil { - ReadUntil(Rc::clone(&self.inner), line.to_vec()) - } - - #[doc(hidden)] - #[inline] - pub fn readall(&self) -> Option { - self.inner.borrow_mut().readall() - } - /// Put unused data back to payload #[inline] pub fn unread_data(&mut self, data: Bytes) { self.inner.borrow_mut().unread_data(data); } - /// Get size of payload buffer - #[inline] - pub fn buffer_size(&self) -> usize { - self.inner.borrow().buffer_size() - } - - /// Set size of payload buffer - #[inline] - pub fn set_buffer_size(&self, size: usize) { - self.inner.borrow_mut().set_buffer_size(size) - } - - /// Convert payload into compatible `HttpResponse` body stream - #[inline] - pub fn stream(self) -> BodyStream { - Box::new(self.map(|i| i.0).map_err(|e| e.into())) + #[cfg(test)] + pub(crate) fn readall(&self) -> Option { + self.inner.borrow_mut().readall() } } impl Stream for Payload { - type Item = PayloadItem; + type Item = Bytes; type Error = PayloadError; #[inline] - fn poll(&mut self) -> Poll, PayloadError> { - self.inner.borrow_mut().readany(false) + fn poll(&mut self) -> Poll, PayloadError> { + self.inner.borrow_mut().readany() } } @@ -159,69 +94,8 @@ impl Clone for Payload { } } -/// Get first available chunk of data -pub struct ReadAny(Rc>); - -impl Stream for ReadAny { - type Item = Bytes; - type Error = PayloadError; - - fn poll(&mut self) -> Poll, Self::Error> { - match self.0.borrow_mut().readany(false)? { - Async::Ready(Some(item)) => Ok(Async::Ready(Some(item.0))), - Async::Ready(None) => Ok(Async::Ready(None)), - Async::NotReady => Ok(Async::NotReady), - } - } -} - -/// Get exact number of bytes -pub struct ReadExactly(Rc>, usize); - -impl Future for ReadExactly { - type Item = Bytes; - type Error = PayloadError; - - fn poll(&mut self) -> Poll { - match self.0.borrow_mut().readexactly(self.1, false)? { - Async::Ready(chunk) => Ok(Async::Ready(chunk)), - Async::NotReady => Ok(Async::NotReady), - } - } -} - -/// Read until `\n` -pub struct ReadLine(Rc>); - -impl Future for ReadLine { - type Item = Bytes; - type Error = PayloadError; - - fn poll(&mut self) -> Poll { - match self.0.borrow_mut().readline(false)? { - Async::Ready(chunk) => Ok(Async::Ready(chunk)), - Async::NotReady => Ok(Async::NotReady), - } - } -} - -/// Read until match line -pub struct ReadUntil(Rc>, Vec); - -impl Future for ReadUntil { - type Item = Bytes; - type Error = PayloadError; - - fn poll(&mut self) -> Poll { - match self.0.borrow_mut().readuntil(&self.1, false)? { - Async::Ready(chunk) => Ok(Async::Ready(chunk)), - Async::NotReady => Ok(Async::NotReady), - } - } -} - /// Payload writer interface. -pub trait PayloadWriter { +pub(crate) trait PayloadWriter { /// Set stream error. fn set_error(&mut self, err: PayloadError); @@ -232,8 +106,8 @@ pub trait PayloadWriter { /// Feed bytes into a payload stream fn feed_data(&mut self, data: Bytes); - /// Get estimated available capacity - fn capacity(&self) -> usize; + /// Need read data + fn need_read(&self) -> PayloadStatus; } /// Sender part of the payload stream @@ -262,59 +136,54 @@ impl PayloadWriter for PayloadSender { } #[inline] - fn capacity(&self) -> usize { + fn need_read(&self) -> PayloadStatus { + // we check need_read only if Payload (other side) is alive, + // otherwise always return true (consume payload) if let Some(shared) = self.inner.upgrade() { - shared.borrow().capacity() + if shared.borrow().need_read { + PayloadStatus::Read + } else { + PayloadStatus::Pause + } } else { - 0 + PayloadStatus::Dropped } } } - #[derive(Debug)] struct Inner { len: usize, eof: bool, err: Option, - task: Option, + need_read: bool, items: VecDeque, - buf_size: usize, } impl Inner { fn new(eof: bool) -> Self { Inner { + eof, len: 0, - eof: eof, err: None, - task: None, items: VecDeque::new(), - buf_size: DEFAULT_BUFFER_SIZE, + need_read: true, } } fn set_error(&mut self, err: PayloadError) { self.err = Some(err); - if let Some(task) = self.task.take() { - task.notify() - } } fn feed_eof(&mut self) { self.eof = true; - if let Some(task) = self.task.take() { - task.notify() - } } fn feed_data(&mut self, data: Bytes) { self.len += data.len(); + self.need_read = false; self.items.push_back(data); - if let Some(task) = self.task.take() { - task.notify() - } } fn eof(&self) -> bool { @@ -325,23 +194,87 @@ impl Inner { self.len } - fn readany(&mut self, notify: bool) -> Poll, PayloadError> { + #[cfg(test)] + pub(crate) fn readall(&mut self) -> Option { + let len = self.items.iter().map(|b| b.len()).sum(); + if len > 0 { + let mut buf = BytesMut::with_capacity(len); + for item in &self.items { + buf.extend_from_slice(item); + } + self.items = VecDeque::new(); + self.len = 0; + Some(buf.take().freeze()) + } else { + self.need_read = true; + None + } + } + + fn readany(&mut self) -> Poll, PayloadError> { if let Some(data) = self.items.pop_front() { self.len -= data.len(); - Ok(Async::Ready(Some(PayloadItem(data)))) + Ok(Async::Ready(Some(data))) } else if let Some(err) = self.err.take() { Err(err) } else if self.eof { Ok(Async::Ready(None)) } else { - if notify { - self.task = Some(current_task()); - } + self.need_read = true; Ok(Async::NotReady) } } - fn readexactly(&mut self, size: usize, notify: bool) -> Result, PayloadError> { + fn unread_data(&mut self, data: Bytes) { + self.len += data.len(); + self.items.push_front(data); + } +} + +pub struct PayloadHelper { + len: usize, + items: VecDeque, + stream: S, +} + +impl PayloadHelper where S: Stream { + + pub fn new(stream: S) -> Self { + PayloadHelper { + len: 0, + items: VecDeque::new(), + stream, + } + } + + fn poll_stream(&mut self) -> Poll { + self.stream.poll().map(|res| { + match res { + Async::Ready(Some(data)) => { + self.len += data.len(); + self.items.push_back(data); + Async::Ready(true) + }, + Async::Ready(None) => Async::Ready(false), + Async::NotReady => Async::NotReady, + } + }) + } + + pub fn readany(&mut self) -> Poll, PayloadError> { + if let Some(data) = self.items.pop_front() { + self.len -= data.len(); + Ok(Async::Ready(Some(data))) + } else { + match self.poll_stream()? { + Async::Ready(true) => self.readany(), + Async::Ready(false) => Ok(Async::Ready(None)), + Async::NotReady => Ok(Async::NotReady), + } + } + } + + pub fn readexactly(&mut self, size: usize) -> Poll, PayloadError> { if size <= self.len { let mut buf = BytesMut::with_capacity(size); while buf.len() < size { @@ -351,22 +284,40 @@ impl Inner { buf.extend_from_slice(&chunk.split_to(rem)); if !chunk.is_empty() { self.items.push_front(chunk); - return Ok(Async::Ready(buf.freeze())) + } + } + return Ok(Async::Ready(Some(buf))) + } + + match self.poll_stream()? { + Async::Ready(true) => self.readexactly(size), + Async::Ready(false) => Ok(Async::Ready(None)), + Async::NotReady => Ok(Async::NotReady), + } + } + + pub fn copy(&mut self, size: usize) -> Poll, PayloadError> { + if size <= self.len { + let mut buf = BytesMut::with_capacity(size); + for chunk in &self.items { + if buf.len() < size { + let rem = cmp::min(size - buf.len(), chunk.len()); + buf.extend_from_slice(&chunk[..rem]); + } + if buf.len() == size { + return Ok(Async::Ready(Some(buf))) } } } - if let Some(err) = self.err.take() { - Err(err) - } else { - if notify { - self.task = Some(current_task()); - } - Ok(Async::NotReady) + match self.poll_stream()? { + Async::Ready(true) => self.copy(size), + Async::Ready(false) => Ok(Async::Ready(None)), + Async::NotReady => Ok(Async::NotReady), } } - fn readuntil(&mut self, line: &[u8], notify: bool) -> Result, PayloadError> { + pub fn readuntil(&mut self, line: &[u8]) -> Poll, PayloadError> { let mut idx = 0; let mut num = 0; let mut offset = 0; @@ -410,58 +361,33 @@ impl Inner { } } self.len -= length; - return Ok(Async::Ready(buf.freeze())) + return Ok(Async::Ready(Some(buf.freeze()))) } } - if let Some(err) = self.err.take() { - Err(err) - } else { - if notify { - self.task = Some(current_task()); - } - Ok(Async::NotReady) + + match self.poll_stream()? { + Async::Ready(true) => self.readuntil(line), + Async::Ready(false) => Ok(Async::Ready(None)), + Async::NotReady => Ok(Async::NotReady), } } - fn readline(&mut self, notify: bool) -> Result, PayloadError> { - self.readuntil(b"\n", notify) + pub fn readline(&mut self) -> Poll, PayloadError> { + self.readuntil(b"\n") } - pub fn readall(&mut self) -> Option { - let len = self.items.iter().map(|b| b.len()).sum(); - if len > 0 { - let mut buf = BytesMut::with_capacity(len); - for item in &self.items { - buf.extend_from_slice(item); - } - self.items = VecDeque::new(); - self.len = 0; - Some(buf.take().freeze()) - } else { - None - } - } - - fn unread_data(&mut self, data: Bytes) { + pub fn unread_data(&mut self, data: Bytes) { self.len += data.len(); self.items.push_front(data); } - #[inline] - fn capacity(&self) -> usize { - if self.len > self.buf_size { - 0 - } else { - self.buf_size - self.len - } - } - - fn buffer_size(&self) -> usize { - self.buf_size - } - - fn set_buffer_size(&mut self, size: usize) { - self.buf_size = size + #[allow(dead_code)] + pub fn remaining(&mut self) -> Bytes { + self.items.iter_mut() + .fold(BytesMut::new(), |mut b, c| { + b.extend_from_slice(c); + b + }).freeze() } } @@ -487,11 +413,10 @@ mod tests { fn test_basic() { Core::new().unwrap().run(lazy(|| { let (_, payload) = Payload::new(false); + let mut payload = PayloadHelper::new(payload); - assert!(!payload.eof()); - assert!(payload.is_empty()); - assert_eq!(payload.len(), 0); - assert_eq!(Async::NotReady, payload.readany().poll().ok().unwrap()); + assert_eq!(payload.len, 0); + assert_eq!(Async::NotReady, payload.readany().ok().unwrap()); let res: Result<(), ()> = Ok(()); result(res) @@ -502,22 +427,17 @@ mod tests { fn test_eof() { Core::new().unwrap().run(lazy(|| { let (mut sender, payload) = Payload::new(false); + let mut payload = PayloadHelper::new(payload); - assert_eq!(Async::NotReady, payload.readany().poll().ok().unwrap()); - assert!(!payload.eof()); - + assert_eq!(Async::NotReady, payload.readany().ok().unwrap()); sender.feed_data(Bytes::from("data")); sender.feed_eof(); - assert!(!payload.eof()); - assert_eq!(Async::Ready(Some(Bytes::from("data"))), - payload.readany().poll().ok().unwrap()); - assert!(payload.is_empty()); - assert!(payload.eof()); - assert_eq!(payload.len(), 0); + payload.readany().ok().unwrap()); + assert_eq!(payload.len, 0); + assert_eq!(Async::Ready(None), payload.readany().ok().unwrap()); - assert_eq!(Async::Ready(None), payload.readany().poll().ok().unwrap()); let res: Result<(), ()> = Ok(()); result(res) })).unwrap(); @@ -527,11 +447,12 @@ mod tests { fn test_err() { Core::new().unwrap().run(lazy(|| { let (mut sender, payload) = Payload::new(false); + let mut payload = PayloadHelper::new(payload); - assert_eq!(Async::NotReady, payload.readany().poll().ok().unwrap()); + assert_eq!(Async::NotReady, payload.readany().ok().unwrap()); sender.set_error(PayloadError::Incomplete); - payload.readany().poll().err().unwrap(); + payload.readany().err().unwrap(); let res: Result<(), ()> = Ok(()); result(res) })).unwrap(); @@ -541,20 +462,18 @@ mod tests { fn test_readany() { Core::new().unwrap().run(lazy(|| { let (mut sender, payload) = Payload::new(false); + let mut payload = PayloadHelper::new(payload); sender.feed_data(Bytes::from("line1")); - - assert!(!payload.is_empty()); - assert_eq!(payload.len(), 5); - sender.feed_data(Bytes::from("line2")); - assert!(!payload.is_empty()); - assert_eq!(payload.len(), 10); assert_eq!(Async::Ready(Some(Bytes::from("line1"))), - payload.readany().poll().ok().unwrap()); - assert!(!payload.is_empty()); - assert_eq!(payload.len(), 5); + payload.readany().ok().unwrap()); + assert_eq!(payload.len, 0); + + assert_eq!(Async::Ready(Some(Bytes::from("line2"))), + payload.readany().ok().unwrap()); + assert_eq!(payload.len, 0); let res: Result<(), ()> = Ok(()); result(res) @@ -565,23 +484,23 @@ mod tests { fn test_readexactly() { Core::new().unwrap().run(lazy(|| { let (mut sender, payload) = Payload::new(false); + let mut payload = PayloadHelper::new(payload); - assert_eq!(Async::NotReady, payload.readexactly(2).poll().ok().unwrap()); + assert_eq!(Async::NotReady, payload.readexactly(2).ok().unwrap()); sender.feed_data(Bytes::from("line1")); sender.feed_data(Bytes::from("line2")); - assert_eq!(payload.len(), 10); - assert_eq!(Async::Ready(Bytes::from("li")), - payload.readexactly(2).poll().ok().unwrap()); - assert_eq!(payload.len(), 8); + assert_eq!(Async::Ready(Some(BytesMut::from("li"))), + payload.readexactly(2).ok().unwrap()); + assert_eq!(payload.len, 3); - assert_eq!(Async::Ready(Bytes::from("ne1l")), - payload.readexactly(4).poll().ok().unwrap()); - assert_eq!(payload.len(), 4); + assert_eq!(Async::Ready(Some(BytesMut::from("ne1l"))), + payload.readexactly(4).ok().unwrap()); + assert_eq!(payload.len, 4); sender.set_error(PayloadError::Incomplete); - payload.readexactly(10).poll().err().unwrap(); + payload.readexactly(10).err().unwrap(); let res: Result<(), ()> = Ok(()); result(res) @@ -592,23 +511,23 @@ mod tests { fn test_readuntil() { Core::new().unwrap().run(lazy(|| { let (mut sender, payload) = Payload::new(false); + let mut payload = PayloadHelper::new(payload); - assert_eq!(Async::NotReady, payload.readuntil(b"ne").poll().ok().unwrap()); + assert_eq!(Async::NotReady, payload.readuntil(b"ne").ok().unwrap()); sender.feed_data(Bytes::from("line1")); sender.feed_data(Bytes::from("line2")); - assert_eq!(payload.len(), 10); - assert_eq!(Async::Ready(Bytes::from("line")), - payload.readuntil(b"ne").poll().ok().unwrap()); - assert_eq!(payload.len(), 6); + assert_eq!(Async::Ready(Some(Bytes::from("line"))), + payload.readuntil(b"ne").ok().unwrap()); + assert_eq!(payload.len, 1); - assert_eq!(Async::Ready(Bytes::from("1line2")), - payload.readuntil(b"2").poll().ok().unwrap()); - assert_eq!(payload.len(), 0); + assert_eq!(Async::Ready(Some(Bytes::from("1line2"))), + payload.readuntil(b"2").ok().unwrap()); + assert_eq!(payload.len, 0); sender.set_error(PayloadError::Incomplete); - payload.readuntil(b"b").poll().err().unwrap(); + payload.readuntil(b"b").err().unwrap(); let res: Result<(), ()> = Ok(()); result(res) @@ -625,7 +544,7 @@ mod tests { assert_eq!(payload.len(), 4); assert_eq!(Async::Ready(Some(Bytes::from("data"))), - payload.readany().poll().ok().unwrap()); + payload.poll().ok().unwrap()); let res: Result<(), ()> = Ok(()); result(res) diff --git a/src/pipeline.rs b/src/pipeline.rs index 2408bf93b..bd7801a36 100644 --- a/src/pipeline.rs +++ b/src/pipeline.rs @@ -72,7 +72,7 @@ struct PipelineInfo { impl PipelineInfo { fn new(req: HttpRequest) -> PipelineInfo { PipelineInfo { - req: req, + req, count: 0, mws: Rc::new(Vec::new()), error: None, @@ -108,9 +108,8 @@ impl> Pipeline { handler: Rc>) -> Pipeline { let mut info = PipelineInfo { - req: req, + req, mws, count: 0, - mws: mws, error: None, context: None, disconnected: None, @@ -174,7 +173,8 @@ impl> HttpHandlerTask for Pipeline { PipelineState::None => return Ok(Async::Ready(true)), PipelineState::Error => - return Err(io::Error::new(io::ErrorKind::Other, "Internal error").into()), + return Err(io::Error::new( + io::ErrorKind::Other, "Internal error").into()), _ => (), } @@ -307,7 +307,7 @@ impl WaitingResponse { RunMiddlewares::init(info, resp), ReplyItem::Future(fut) => PipelineState::Handler( - WaitingResponse { fut: fut, _s: PhantomData, _h: PhantomData }), + WaitingResponse { fut, _s: PhantomData, _h: PhantomData }), } } @@ -355,7 +355,7 @@ impl RunMiddlewares { }, Ok(Response::Future(fut)) => { return PipelineState::RunMiddlewares( - RunMiddlewares { curr: curr, fut: Some(fut), + RunMiddlewares { curr, fut: Some(fut), _s: PhantomData, _h: PhantomData }) }, }; @@ -444,7 +444,7 @@ impl ProcessResponse { #[inline] fn init(resp: HttpResponse) -> PipelineState { PipelineState::Response( - ProcessResponse{ resp: resp, + ProcessResponse{ resp, iostate: IOState::Response, running: RunningState::Running, drain: None, _s: PhantomData, _h: PhantomData}) @@ -644,7 +644,7 @@ impl FinishingMiddlewares { if info.count == 0 { Completed::init(info) } else { - let mut state = FinishingMiddlewares{resp: resp, fut: None, + let mut state = FinishingMiddlewares{resp, fut: None, _s: PhantomData, _h: PhantomData}; if let Some(st) = state.poll(info) { st diff --git a/src/pred.rs b/src/pred.rs index 47d906fb0..b49d4ec58 100644 --- a/src/pred.rs +++ b/src/pred.rs @@ -3,6 +3,7 @@ use std::marker::PhantomData; use http; use http::{header, HttpTryFrom}; +use httpmessage::HttpMessage; use httprequest::HttpRequest; /// Trait defines resource route predicate. @@ -27,8 +28,8 @@ pub trait Predicate { /// fn main() { /// Application::new() /// .resource("/index.html", |r| r.route() -/// .p(pred::Any(pred::Get()).or(pred::Post())) -/// .h(HTTPMethodNotAllowed)); +/// .filter(pred::Any(pred::Get()).or(pred::Post())) +/// .h(HttpMethodNotAllowed)); /// } /// ``` pub fn Any + 'static>(pred: P) -> AnyPredicate @@ -70,9 +71,9 @@ impl Predicate for AnyPredicate { /// fn main() { /// Application::new() /// .resource("/index.html", |r| r.route() -/// .p(pred::All(pred::Get()) +/// .filter(pred::All(pred::Get()) /// .and(pred::Header("content-type", "plain/text"))) -/// .h(HTTPMethodNotAllowed)); +/// .h(HttpMethodNotAllowed)); /// } /// ``` pub fn All + 'static>(pred: P) -> AllPredicate { diff --git a/src/resource.rs b/src/resource.rs index eb783a227..2e83225ea 100644 --- a/src/resource.rs +++ b/src/resource.rs @@ -81,8 +81,8 @@ impl Resource { /// let app = Application::new() /// .resource( /// "/", |r| r.route() - /// .p(pred::Any(pred::Get()).or(pred::Put())) - /// .p(pred::Header("Content-Type", "text/plain")) + /// .filter(pred::Any(pred::Get()).or(pred::Put())) + /// .filter(pred::Header("Content-Type", "text/plain")) /// .f(|r| HttpResponse::Ok())) /// .finish(); /// } @@ -97,11 +97,11 @@ impl Resource { /// This is shortcut for: /// /// ```rust,ignore - /// Resource::resource("/", |r| r.route().p(pred::Get()).f(index) + /// Resource::resource("/", |r| r.route().filter(pred::Get()).f(index) /// ``` pub fn method(&mut self, method: Method) -> &mut Route { self.routes.push(Route::default()); - self.routes.last_mut().unwrap().p(pred::Method(method)) + self.routes.last_mut().unwrap().filter(pred::Method(method)) } /// Register a new route and add handler object. diff --git a/src/route.rs b/src/route.rs index bd721b1c6..856d6fa85 100644 --- a/src/route.rs +++ b/src/route.rs @@ -8,7 +8,7 @@ use pred::Predicate; use handler::{Reply, ReplyItem, Handler, Responder, RouteHandler, AsyncHandler, WrapHandler}; use middleware::{Middleware, Response as MiddlewareResponse, Started as MiddlewareStarted}; -use httpcodes::HTTPNotFound; +use httpcodes::HttpNotFound; use httprequest::HttpRequest; use httpresponse::HttpResponse; @@ -26,7 +26,7 @@ impl Default for Route { fn default() -> Route { Route { preds: Vec::new(), - handler: InnerHandler::new(|_| HTTPNotFound), + handler: InnerHandler::new(|_| HttpNotFound), } } } @@ -65,18 +65,24 @@ impl Route { /// Application::new() /// .resource("/path", |r| /// r.route() - /// .p(pred::Get()) - /// .p(pred::Header("content-type", "text/plain")) - /// .f(|req| HTTPOk) + /// .filter(pred::Get()) + /// .filter(pred::Header("content-type", "text/plain")) + /// .f(|req| HttpOk) /// ) /// # .finish(); /// # } /// ``` - pub fn p + 'static>(&mut self, p: T) -> &mut Self { + pub fn filter + 'static>(&mut self, p: T) -> &mut Self { self.preds.push(Box::new(p)); self } + #[doc(hidden)] + #[deprecated(since="0.4.1", note="please use `.filter()` instead")] + pub fn p + 'static>(&mut self, p: T) -> &mut Self { + self.filter(p) + } + /// Set handler object. Usually call to this method is last call /// during route configuration, because it does not return reference to self. pub fn h>(&mut self, handler: H) { @@ -179,14 +185,10 @@ impl Compose { mws: Rc>>>, handler: InnerHandler) -> Self { - let mut info = ComposeInfo { - count: 0, - req: req, - mws: mws, - handler: handler }; + let mut info = ComposeInfo { count: 0, req, mws, handler }; let state = StartMiddlewares::init(&mut info); - Compose {state: state, info: info} + Compose {state, info} } } @@ -308,7 +310,7 @@ impl WaitingResponse { RunMiddlewares::init(info, resp), ReplyItem::Future(fut) => ComposeState::Handler( - WaitingResponse { fut: fut, _s: PhantomData }), + WaitingResponse { fut, _s: PhantomData }), } } @@ -353,7 +355,7 @@ impl RunMiddlewares { }, Ok(MiddlewareResponse::Future(fut)) => { return ComposeState::RunMiddlewares( - RunMiddlewares { curr: curr, fut: Some(fut), _s: PhantomData }) + RunMiddlewares { curr, fut: Some(fut), _s: PhantomData }) }, }; } diff --git a/src/router.rs b/src/router.rs index 9f1d93b05..fc01bd3be 100644 --- a/src/router.rs +++ b/src/router.rs @@ -46,13 +46,9 @@ impl Router { } } - let len = prefix.len(); + let prefix_len = prefix.len(); (Router(Rc::new( - Inner{ prefix: prefix, - prefix_len: len, - named: named, - patterns: patterns, - srv: settings })), resources) + Inner{ prefix, prefix_len, named, patterns, srv: settings })), resources) } /// Router prefix @@ -152,7 +148,12 @@ impl Pattern { /// /// Panics if path pattern is wrong. pub fn new(name: &str, path: &str) -> Self { - let (pattern, elements, is_dynamic) = Pattern::parse(path); + Pattern::with_prefix(name, path, "/") + } + + /// Parse path pattern and create new `Pattern` instance with custom prefix + pub fn with_prefix(name: &str, path: &str, prefix: &str) -> Self { + let (pattern, elements, is_dynamic) = Pattern::parse(path, prefix); let tp = if is_dynamic { let re = match Regex::new(&pattern) { @@ -168,10 +169,10 @@ impl Pattern { }; Pattern { - tp: tp, + tp, + pattern, + elements, name: name.into(), - pattern: pattern, - elements: elements, } } @@ -192,7 +193,9 @@ impl Pattern { } } - pub fn match_with_params<'a>(&'a self, path: &'a str, params: &'a mut Params<'a>) -> bool { + pub fn match_with_params<'a>(&'a self, path: &'a str, params: &'a mut Params<'a>) + -> bool + { match self.tp { PatternType::Static(ref s) => s == path, PatternType::Dynamic(ref re, ref names) => { @@ -240,11 +243,11 @@ impl Pattern { Ok(path) } - fn parse(pattern: &str) -> (String, Vec, bool) { + fn parse(pattern: &str, prefix: &str) -> (String, Vec, bool) { const DEFAULT_PATTERN: &str = "[^/]+"; - let mut re1 = String::from("^/"); - let mut re2 = String::from("/"); + let mut re1 = String::from("^") + prefix; + let mut re2 = String::from(prefix); let mut el = String::new(); let mut in_param = false; let mut in_param_pattern = false; diff --git a/src/server/channel.rs b/src/server/channel.rs index 85c3ac4ef..390aaee87 100644 --- a/src/server/channel.rs +++ b/src/server/channel.rs @@ -18,15 +18,6 @@ enum HttpProtocol { Unknown(Rc>, Option, T, BytesMut), } -impl HttpProtocol { - fn is_unknown(&self) -> bool { - match *self { - HttpProtocol::Unknown(_, _, _, _) => true, - _ => false - } - } -} - enum ProtocolKind { Http1, Http2, @@ -41,18 +32,18 @@ pub struct HttpChannel where T: IoStream, H: HttpHandler + 'static { impl HttpChannel where T: IoStream, H: HttpHandler + 'static { pub(crate) fn new(settings: Rc>, - io: T, peer: Option, http2: bool) -> HttpChannel + mut io: T, peer: Option, http2: bool) -> HttpChannel { settings.add_channel(); + let _ = io.set_nodelay(true); + if http2 { HttpChannel { - node: None, - proto: Some(HttpProtocol::H2( + node: None, proto: Some(HttpProtocol::H2( h2::Http2::new(settings, io, peer, Bytes::new()))) } } else { HttpChannel { - node: None, - proto: Some(HttpProtocol::Unknown( + node: None, proto: Some(HttpProtocol::Unknown( settings, peer, io, BytesMut::with_capacity(4096))) } } } @@ -78,15 +69,18 @@ impl Future for HttpChannel where T: IoStream, H: HttpHandler + 'sta type Error = (); fn poll(&mut self) -> Poll { - if !self.proto.as_ref().map(|p| p.is_unknown()).unwrap_or(false) && self.node.is_none() { - self.node = Some(Node::new(self)); - match self.proto { + if !self.node.is_none() { + let el = self as *mut _; + self.node = Some(Node::new(el)); + let _ = match self.proto { Some(HttpProtocol::H1(ref mut h1)) => - h1.settings().head().insert(self.node.as_ref().unwrap()), + self.node.as_ref().map(|n| h1.settings().head().insert(n)), Some(HttpProtocol::H2(ref mut h2)) => - h2.settings().head().insert(self.node.as_ref().unwrap()), - _ => (), - } + self.node.as_ref().map(|n| h2.settings().head().insert(n)), + Some(HttpProtocol::Unknown(ref mut settings, _, _, _)) => + self.node.as_ref().map(|n| settings.head().insert(n)), + None => unreachable!(), + }; } let kind = match self.proto { @@ -95,7 +89,7 @@ impl Future for HttpChannel where T: IoStream, H: HttpHandler + 'sta match result { Ok(Async::Ready(())) | Err(_) => { h1.settings().remove_channel(); - self.node.as_ref().unwrap().remove(); + self.node.as_mut().map(|n| n.remove()); }, _ => (), } @@ -106,7 +100,7 @@ impl Future for HttpChannel where T: IoStream, H: HttpHandler + 'sta match result { Ok(Async::Ready(())) | Err(_) => { h2.settings().remove_channel(); - self.node.as_ref().unwrap().remove(); + self.node.as_mut().map(|n| n.remove()); }, _ => (), } @@ -117,6 +111,7 @@ impl Future for HttpChannel where T: IoStream, H: HttpHandler + 'sta Ok(Async::Ready(0)) | Err(_) => { debug!("Ignored premature client disconnection"); settings.remove_channel(); + self.node.as_mut().map(|n| n.remove()); return Err(()) }, _ => (), @@ -163,11 +158,11 @@ pub(crate) struct Node impl Node { - fn new(el: &mut T) -> Self { + fn new(el: *mut T) -> Self { Node { next: None, prev: None, - element: el as *mut _, + element: el, } } @@ -186,13 +181,14 @@ impl Node } } - fn remove(&self) { - #[allow(mutable_transmutes)] + fn remove(&mut self) { unsafe { - if let Some(ref prev) = self.prev { - let p: &mut Node<()> = mem::transmute(prev.as_ref().unwrap()); - let slf: &mut Node = mem::transmute(self); - p.next = slf.next.take(); + self.element = ptr::null_mut(); + let next = self.next.take(); + let mut prev = self.prev.take(); + + if let Some(ref mut prev) = prev { + prev.as_mut().unwrap().next = next; } } } @@ -237,7 +233,7 @@ pub(crate) struct WrapperStream where T: AsyncRead + AsyncWrite + 'static { impl WrapperStream where T: AsyncRead + AsyncWrite + 'static { pub fn new(io: T) -> Self { - WrapperStream{io: io} + WrapperStream{ io } } } diff --git a/src/server/encoding.rs b/src/server/encoding.rs index 964754ab0..23f4aef7f 100644 --- a/src/server/encoding.rs +++ b/src/server/encoding.rs @@ -11,14 +11,14 @@ use flate2::Compression; use flate2::read::GzDecoder; use flate2::write::{GzEncoder, DeflateDecoder, DeflateEncoder}; use brotli2::write::{BrotliDecoder, BrotliEncoder}; -use bytes::{Bytes, BytesMut, BufMut, Writer}; +use bytes::{Bytes, BytesMut, BufMut}; use headers::ContentEncoding; use body::{Body, Binary}; use error::PayloadError; -use httprequest::HttpMessage; +use httprequest::HttpInnerMessage; use httpresponse::HttpResponse; -use payload::{PayloadSender, PayloadWriter}; +use payload::{PayloadSender, PayloadWriter, PayloadStatus}; use super::shared::SharedBytes; @@ -120,26 +120,84 @@ impl PayloadWriter for PayloadType { } #[inline] - fn capacity(&self) -> usize { + fn need_read(&self) -> PayloadStatus { match *self { - PayloadType::Sender(ref sender) => sender.capacity(), - PayloadType::Encoding(ref enc) => enc.capacity(), + PayloadType::Sender(ref sender) => sender.need_read(), + PayloadType::Encoding(ref enc) => enc.need_read(), } } } -enum Decoder { - Deflate(Box>>), + +/// Payload wrapper with content decompression support +pub(crate) struct EncodedPayload { + inner: PayloadSender, + error: bool, + payload: PayloadStream, +} + +impl EncodedPayload { + pub fn new(inner: PayloadSender, enc: ContentEncoding) -> EncodedPayload { + EncodedPayload{ inner, error: false, payload: PayloadStream::new(enc) } + } +} + +impl PayloadWriter for EncodedPayload { + + fn set_error(&mut self, err: PayloadError) { + self.inner.set_error(err) + } + + fn feed_eof(&mut self) { + if !self.error { + match self.payload.feed_eof() { + Err(err) => { + self.error = true; + self.set_error(PayloadError::Io(err)); + }, + Ok(value) => { + if let Some(b) = value { + self.inner.feed_data(b); + } + self.inner.feed_eof(); + } + } + } + } + + fn feed_data(&mut self, data: Bytes) { + if self.error { + return + } + + match self.payload.feed_data(data) { + Ok(Some(b)) => self.inner.feed_data(b), + Ok(None) => (), + Err(e) => { + self.error = true; + self.set_error(e.into()); + } + } + } + + #[inline] + fn need_read(&self) -> PayloadStatus { + self.inner.need_read() + } +} + +pub(crate) enum Decoder { + Deflate(Box>), Gzip(Option>>), - Br(Box>>), + Br(Box>), Identity, } // should go after write::GzDecoder get implemented #[derive(Debug)] -struct Wrapper { - buf: BytesMut, - eof: bool, +pub(crate) struct Wrapper { + pub buf: BytesMut, + pub eof: bool, } impl io::Read for Wrapper { @@ -169,50 +227,64 @@ impl io::Write for Wrapper { } } -/// Payload wrapper with content decompression support -pub(crate) struct EncodedPayload { - inner: PayloadSender, - decoder: Decoder, - dst: BytesMut, - error: bool, +pub(crate) struct Writer { + buf: BytesMut, } -impl EncodedPayload { - pub fn new(inner: PayloadSender, enc: ContentEncoding) -> EncodedPayload { +impl Writer { + fn new() -> Writer { + Writer{buf: BytesMut::with_capacity(8192)} + } + fn take(&mut self) -> Bytes { + self.buf.take().freeze() + } +} + +impl io::Write for Writer { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.buf.extend_from_slice(buf); + Ok(buf.len()) + } + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} + +/// Payload stream with decompression support +pub(crate) struct PayloadStream { + decoder: Decoder, + dst: BytesMut, +} + +impl PayloadStream { + pub fn new(enc: ContentEncoding) -> PayloadStream { let dec = match enc { ContentEncoding::Br => Decoder::Br( - Box::new(BrotliDecoder::new(BytesMut::with_capacity(8192).writer()))), + Box::new(BrotliDecoder::new(Writer::new()))), ContentEncoding::Deflate => Decoder::Deflate( - Box::new(DeflateDecoder::new(BytesMut::with_capacity(8192).writer()))), + Box::new(DeflateDecoder::new(Writer::new()))), ContentEncoding::Gzip => Decoder::Gzip(None), _ => Decoder::Identity, }; - EncodedPayload{ inner: inner, decoder: dec, error: false, dst: BytesMut::new() } + PayloadStream{ decoder: dec, dst: BytesMut::new() } } } -impl PayloadWriter for EncodedPayload { +impl PayloadStream { - fn set_error(&mut self, err: PayloadError) { - self.inner.set_error(err) - } - - fn feed_eof(&mut self) { - if self.error { - return - } - let err = match self.decoder { + pub fn feed_eof(&mut self) -> io::Result> { + match self.decoder { Decoder::Br(ref mut decoder) => { match decoder.finish() { Ok(mut writer) => { - let b = writer.get_mut().take().freeze(); + let b = writer.take(); if !b.is_empty() { - self.inner.feed_data(b); + Ok(Some(b)) + } else { + Ok(None) } - self.inner.feed_eof(); - return }, - Err(err) => Some(err), + Err(e) => Err(e), } }, Decoder::Gzip(ref mut decoder) => { @@ -224,66 +296,50 @@ impl PayloadWriter for EncodedPayload { match decoder.read(unsafe{self.dst.bytes_mut()}) { Ok(n) => { if n == 0 { - self.inner.feed_eof(); - return + return Ok(Some(self.dst.take().freeze())) } else { - unsafe{self.dst.set_len(n)}; - self.inner.feed_data(self.dst.split_to(n).freeze()); + unsafe{self.dst.advance_mut(n)}; } } - Err(err) => { - break Some(err); - } + Err(e) => return Err(e), } } } else { - return + Ok(None) } }, Decoder::Deflate(ref mut decoder) => { match decoder.try_finish() { Ok(_) => { - let b = decoder.get_mut().get_mut().take().freeze(); + let b = decoder.get_mut().take(); if !b.is_empty() { - self.inner.feed_data(b); + Ok(Some(b)) + } else { + Ok(None) } - self.inner.feed_eof(); - return }, - Err(err) => Some(err), + Err(e) => Err(e), } }, - Decoder::Identity => { - self.inner.feed_eof(); - return - } - }; - - self.error = true; - self.decoder = Decoder::Identity; - if let Some(err) = err { - self.set_error(PayloadError::Io(err)); - } else { - self.set_error(PayloadError::Incomplete); + Decoder::Identity => Ok(None), } } - fn feed_data(&mut self, data: Bytes) { - if self.error { - return - } + pub fn feed_data(&mut self, data: Bytes) -> io::Result> { match self.decoder { Decoder::Br(ref mut decoder) => { - if decoder.write(&data).is_ok() && decoder.flush().is_ok() { - let b = decoder.get_mut().get_mut().take().freeze(); - if !b.is_empty() { - self.inner.feed_data(b); - } - return + match decoder.write(&data).and_then(|_| decoder.flush()) { + Ok(_) => { + let b = decoder.get_mut().take(); + if !b.is_empty() { + Ok(Some(b)) + } else { + Ok(None) + } + }, + Err(e) => Err(e) } - trace!("Error decoding br encoding"); - } - + }, Decoder::Gzip(ref mut decoder) => { if decoder.is_none() { *decoder = Some( @@ -298,60 +354,52 @@ impl PayloadWriter for EncodedPayload { match decoder.as_mut().as_mut().unwrap().read(unsafe{self.dst.bytes_mut()}) { Ok(n) => { if n == 0 { - return + return Ok(Some(self.dst.take().freeze())); } else { - unsafe{self.dst.set_len(n)}; - self.inner.feed_data(self.dst.split_to(n).freeze()); + unsafe{self.dst.advance_mut(n)}; } } Err(e) => { - if e.kind() == io::ErrorKind::WouldBlock { - return - } - break + return Err(e) } } } - } - + }, Decoder::Deflate(ref mut decoder) => { - if decoder.write(&data).is_ok() && decoder.flush().is_ok() { - let b = decoder.get_mut().get_mut().take().freeze(); - if !b.is_empty() { - self.inner.feed_data(b); - } - return + match decoder.write(&data).and_then(|_| decoder.flush()) { + Ok(_) => { + let b = decoder.get_mut().take(); + if !b.is_empty() { + Ok(Some(b)) + } else { + Ok(None) + } + }, + Err(e) => Err(e), } - trace!("Error decoding deflate encoding"); - } - Decoder::Identity => { - self.inner.feed_data(data); - return - } - }; - - self.error = true; - self.decoder = Decoder::Identity; - self.set_error(PayloadError::EncodingCorrupted); - } - - fn capacity(&self) -> usize { - self.inner.capacity() + }, + Decoder::Identity => Ok(Some(data)), + } } } -pub(crate) struct PayloadEncoder(ContentEncoder); +pub(crate) enum ContentEncoder { + Deflate(DeflateEncoder), + Gzip(GzEncoder), + Br(BrotliEncoder), + Identity(TransferEncoding), +} -impl PayloadEncoder { +impl ContentEncoder { - pub fn empty(bytes: SharedBytes) -> PayloadEncoder { - PayloadEncoder(ContentEncoder::Identity(TransferEncoding::eof(bytes))) + pub fn empty(bytes: SharedBytes) -> ContentEncoder { + ContentEncoder::Identity(TransferEncoding::eof(bytes)) } - pub fn new(buf: SharedBytes, - req: &HttpMessage, - resp: &mut HttpResponse, - response_encoding: ContentEncoding) -> PayloadEncoder + pub fn for_server(buf: SharedBytes, + req: &HttpInnerMessage, + resp: &mut HttpResponse, + response_encoding: ContentEncoding) -> ContentEncoder { let version = resp.version().unwrap_or_else(|| req.version); let mut body = resp.replace_body(Body::Empty); @@ -440,7 +488,7 @@ impl PayloadEncoder { } TransferEncoding::eof(buf) } else { - PayloadEncoder::streaming_encoding(buf, version, resp) + ContentEncoder::streaming_encoding(buf, version, resp) } } }; @@ -451,18 +499,16 @@ impl PayloadEncoder { resp.replace_body(body); } - PayloadEncoder( - match encoding { - ContentEncoding::Deflate => ContentEncoder::Deflate( - DeflateEncoder::new(transfer, Compression::default())), - ContentEncoding::Gzip => ContentEncoder::Gzip( - GzEncoder::new(transfer, Compression::default())), - ContentEncoding::Br => ContentEncoder::Br( - BrotliEncoder::new(transfer, 5)), - ContentEncoding::Identity => ContentEncoder::Identity(transfer), - ContentEncoding::Auto => unreachable!() - } - ) + match encoding { + ContentEncoding::Deflate => ContentEncoder::Deflate( + DeflateEncoder::new(transfer, Compression::default())), + ContentEncoding::Gzip => ContentEncoder::Gzip( + GzEncoder::new(transfer, Compression::default())), + ContentEncoding::Br => ContentEncoder::Br( + BrotliEncoder::new(transfer, 5)), + ContentEncoding::Identity => ContentEncoder::Identity(transfer), + ContentEncoding::Auto => unreachable!() + } } fn streaming_encoding(buf: SharedBytes, version: Version, @@ -527,33 +573,6 @@ impl PayloadEncoder { } } -impl PayloadEncoder { - - #[inline] - pub fn is_eof(&self) -> bool { - self.0.is_eof() - } - - #[cfg_attr(feature = "cargo-clippy", allow(inline_always))] - #[inline(always)] - pub fn write(&mut self, payload: Binary) -> Result<(), io::Error> { - self.0.write(payload) - } - - #[cfg_attr(feature = "cargo-clippy", allow(inline_always))] - #[inline(always)] - pub fn write_eof(&mut self) -> Result<(), io::Error> { - self.0.write_eof() - } -} - -pub(crate) enum ContentEncoder { - Deflate(DeflateEncoder), - Gzip(GzEncoder), - Br(BrotliEncoder), - Identity(TransferEncoding), -} - impl ContentEncoder { #[inline] @@ -829,10 +848,7 @@ impl AcceptEncoding { Err(_) => 0.0, } }; - Some(AcceptEncoding { - encoding: encoding, - quality: quality, - }) + Some(AcceptEncoding{ encoding, quality }) } /// Parse a raw Accept-Encoding header value into an ordered list. diff --git a/src/server/h1.rs b/src/server/h1.rs index 4ce403cb5..a55ac2799 100644 --- a/src/server/h1.rs +++ b/src/server/h1.rs @@ -1,3 +1,5 @@ +#![cfg_attr(feature = "cargo-clippy", allow(redundant_field_names))] + use std::{self, io}; use std::rc::Rc; use std::net::SocketAddr; @@ -13,10 +15,10 @@ use futures::{Future, Poll, Async}; use tokio_core::reactor::Timeout; use pipeline::Pipeline; -use httpcodes::HTTPNotFound; +use httpcodes::HttpNotFound; use httprequest::HttpRequest; use error::{ParseError, PayloadError, ResponseError}; -use payload::{Payload, PayloadWriter}; +use payload::{Payload, PayloadWriter, PayloadStatus}; use super::{utils, Writer}; use super::h1writer::H1Writer; @@ -62,18 +64,20 @@ struct Entry { impl Http1 where T: IoStream, H: HttpHandler + 'static { - pub fn new(h: Rc>, stream: T, addr: Option, buf: BytesMut) - -> Self + pub fn new(settings: Rc>, + stream: T, + addr: Option, read_buf: BytesMut) -> Self { - let bytes = h.get_shared_bytes(); + let bytes = settings.get_shared_bytes(); Http1{ flags: Flags::KEEPALIVE, - settings: h, - addr: addr, stream: H1Writer::new(stream, bytes), reader: Reader::new(), - read_buf: buf, tasks: VecDeque::new(), - keepalive_timer: None } + keepalive_timer: None, + addr, + read_buf, + settings, + } } pub fn settings(&self) -> &WorkerSettings { @@ -84,18 +88,6 @@ impl Http1 self.stream.get_mut() } - fn poll_completed(&mut self, shutdown: bool) -> Result { - // check stream state - match self.stream.poll_completed(shutdown) { - Ok(Async::Ready(_)) => Ok(true), - Ok(Async::NotReady) => Ok(false), - Err(err) => { - debug!("Error sending data: {}", err); - Err(()) - } - } - } - pub fn poll(&mut self) -> Poll<(), ()> { // keep-alive timer if let Some(ref mut timer) = self.keepalive_timer { @@ -109,14 +101,32 @@ impl Http1 } } - self.poll_io() + loop { + match self.poll_io()? { + Async::Ready(true) => (), + Async::Ready(false) => return Ok(Async::Ready(())), + Async::NotReady => return Ok(Async::NotReady), + } + } + } + + fn poll_completed(&mut self, shutdown: bool) -> Result { + // check stream state + match self.stream.poll_completed(shutdown) { + Ok(Async::Ready(_)) => Ok(true), + Ok(Async::NotReady) => Ok(false), + Err(err) => { + debug!("Error sending data: {}", err); + Err(()) + } + } } // TODO: refactor - pub fn poll_io(&mut self) -> Poll<(), ()> { + pub fn poll_io(&mut self) -> Poll { // read incoming data - let need_read = - if !self.flags.contains(Flags::ERROR) && self.tasks.len() < MAX_PIPELINED_MESSAGES + let need_read = if !self.flags.intersects(Flags::ERROR) && + self.tasks.len() < MAX_PIPELINED_MESSAGES { 'outer: loop { match self.reader.parse(self.stream.get_mut(), @@ -131,9 +141,9 @@ impl Http1 // start request processing for h in self.settings.handlers().iter_mut() { req = match h.handle(req) { - Ok(t) => { + Ok(pipe) => { self.tasks.push_back( - Entry {pipe: t, flags: EntryFlags::empty()}); + Entry {pipe, flags: EntryFlags::empty()}); continue 'outer }, Err(req) => req, @@ -141,18 +151,11 @@ impl Http1 } self.tasks.push_back( - Entry {pipe: Pipeline::error(HTTPNotFound), + Entry {pipe: Pipeline::error(HttpNotFound), flags: EntryFlags::empty()}); continue }, Ok(Async::NotReady) => (), - Err(ReaderError::Disconnect) => { - self.flags.insert(Flags::ERROR); - self.stream.disconnected(); - for entry in &mut self.tasks { - entry.pipe.disconnected() - } - }, Err(err) => { // notify all tasks self.stream.disconnected(); @@ -167,12 +170,16 @@ impl Http1 // on parse error, stop reading stream but tasks need to be completed self.flags.insert(Flags::ERROR); - if self.tasks.is_empty() { - if let ReaderError::Error(err) = err { - self.tasks.push_back( - Entry {pipe: Pipeline::error(err.error_response()), - flags: EntryFlags::empty()}); - } + match err { + ReaderError::Disconnect => (), + _ => + if self.tasks.is_empty() { + if let ReaderError::Error(err) = err { + self.tasks.push_back( + Entry {pipe: Pipeline::error(err.error_response()), + flags: EntryFlags::empty()}); + } + } } }, } @@ -183,6 +190,8 @@ impl Http1 true }; + let retry = self.reader.need_read() == PayloadStatus::Read; + loop { // check in-flight messages let mut io = false; @@ -217,7 +226,12 @@ impl Http1 } }, // no more IO for this iteration - Ok(Async::NotReady) => io = true, + Ok(Async::NotReady) => { + if self.reader.need_read() == PayloadStatus::Read && !retry { + return Ok(Async::Ready(true)); + } + io = true; + } Err(err) => { // it is not possible to recover from error // during pipe handling, so just drop connection @@ -264,14 +278,14 @@ impl Http1 if !self.poll_completed(true)? { return Ok(Async::NotReady) } - return Ok(Async::Ready(())) + return Ok(Async::Ready(false)) } // start keep-alive timer, this also is slow request timeout if self.tasks.is_empty() { // check stream state if self.flags.contains(Flags::ERROR) { - return Ok(Async::Ready(())) + return Ok(Async::Ready(false)) } if self.settings.keep_alive_enabled() { @@ -291,7 +305,7 @@ impl Http1 return Ok(Async::NotReady) } // keep-alive is disabled, drop connection - return Ok(Async::Ready(())) + return Ok(Async::Ready(false)) } } else if !self.poll_completed(false)? || self.flags.contains(Flags::KEEPALIVE) { @@ -299,7 +313,7 @@ impl Http1 // if keep-alive unset, rely on operating system return Ok(Async::NotReady) } else { - return Ok(Async::Ready(())) + return Ok(Async::Ready(false)) } } else { self.poll_completed(false)?; @@ -327,6 +341,7 @@ struct PayloadInfo { enum ReaderError { Disconnect, Payload, + PayloadDropped, Error(ParseError), } @@ -337,14 +352,27 @@ impl Reader { } } + #[inline] + fn need_read(&self) -> PayloadStatus { + if let Some(ref info) = self.payload { + info.tx.need_read() + } else { + PayloadStatus::Read + } + } + #[inline] fn decode(&mut self, buf: &mut BytesMut, payload: &mut PayloadInfo) -> Result { - loop { + while !buf.is_empty() { match payload.decoder.decode(buf) { Ok(Async::Ready(Some(bytes))) => { - payload.tx.feed_data(bytes) + payload.tx.feed_data(bytes); + if payload.decoder.is_eof() { + payload.tx.feed_eof(); + return Ok(Decoding::Ready) + } }, Ok(Async::Ready(None)) => { payload.tx.feed_eof(); @@ -357,6 +385,7 @@ impl Reader { } } } + Ok(Decoding::NotReady) } pub fn parse(&mut self, io: &mut T, @@ -364,42 +393,58 @@ impl Reader { settings: &WorkerSettings) -> Poll where T: IoStream { + match self.need_read() { + PayloadStatus::Read => (), + PayloadStatus::Pause => return Ok(Async::NotReady), + PayloadStatus::Dropped => return Err(ReaderError::PayloadDropped), + } + // read payload let done = { if let Some(ref mut payload) = self.payload { - if payload.tx.capacity() == 0 { - return Ok(Async::NotReady) - } - match utils::read_from_io(io, buf) { - Ok(Async::Ready(0)) => { - payload.tx.set_error(PayloadError::Incomplete); + 'buf: loop { + let not_ready = match utils::read_from_io(io, buf) { + Ok(Async::Ready(0)) => { + payload.tx.set_error(PayloadError::Incomplete); - // http channel should not deal with payload errors - return Err(ReaderError::Payload) - }, - Err(err) => { - payload.tx.set_error(err.into()); - - // http channel should not deal with payload errors - return Err(ReaderError::Payload) - } - _ => (), - } - loop { - match payload.decoder.decode(buf) { - Ok(Async::Ready(Some(bytes))) => { - payload.tx.feed_data(bytes) + // http channel should not deal with payload errors + return Err(ReaderError::Payload) }, - Ok(Async::Ready(None)) => { - payload.tx.feed_eof(); - break true - }, - Ok(Async::NotReady) => - break false, + Ok(Async::NotReady) => true, Err(err) => { payload.tx.set_error(err.into()); + + // http channel should not deal with payload errors return Err(ReaderError::Payload) } + _ => false, + }; + loop { + match payload.decoder.decode(buf) { + Ok(Async::Ready(Some(bytes))) => { + payload.tx.feed_data(bytes); + if payload.decoder.is_eof() { + payload.tx.feed_eof(); + break 'buf true + } + }, + Ok(Async::Ready(None)) => { + payload.tx.feed_eof(); + break 'buf true + }, + Ok(Async::NotReady) => { + // if buffer is full then + // socket still can contain more data + if not_ready { + return Ok(Async::NotReady) + } + continue 'buf + }, + Err(err) => { + payload.tx.set_error(err.into()); + return Err(ReaderError::Payload) + } + } } } } else { @@ -409,16 +454,13 @@ impl Reader { if done { self.payload = None } // if buf is empty parse_message will always return NotReady, let's avoid that - let read = if buf.is_empty() { + if buf.is_empty() { match utils::read_from_io(io, buf) { Ok(Async::Ready(0)) => return Err(ReaderError::Disconnect), Ok(Async::Ready(_)) => (), Ok(Async::NotReady) => return Ok(Async::NotReady), Err(err) => return Err(ReaderError::Error(err.into())) } - false - } else { - true }; loop { @@ -434,22 +476,18 @@ impl Reader { return Ok(Async::Ready(msg)); }, Async::NotReady => { - if buf.capacity() >= MAX_BUFFER_SIZE { + if buf.len() >= MAX_BUFFER_SIZE { error!("MAX_BUFFER_SIZE unprocessed data reached, closing"); return Err(ReaderError::Error(ParseError::TooLarge)); } - if read { - match utils::read_from_io(io, buf) { - Ok(Async::Ready(0)) => { - debug!("Ignored premature client disconnection"); - return Err(ReaderError::Disconnect); - }, - Ok(Async::Ready(_)) => (), - Ok(Async::NotReady) => return Ok(Async::NotReady), - Err(err) => return Err(ReaderError::Error(err.into())), - } - } else { - return Ok(Async::NotReady) + match utils::read_from_io(io, buf) { + Ok(Async::Ready(0)) => { + debug!("Ignored premature client disconnection"); + return Err(ReaderError::Disconnect); + }, + Ok(Async::Ready(_)) => (), + Ok(Async::NotReady) => return Ok(Async::NotReady), + Err(err) => return Err(ReaderError::Error(err.into())), } }, } @@ -540,7 +578,7 @@ impl Reader { let (psender, payload) = Payload::new(false); let info = PayloadInfo { tx: PayloadType::new(&msg.get_mut().headers, psender), - decoder: decoder, + decoder, }; msg.get_mut().payload = Some(payload); Ok(Async::Ready((HttpRequest::from_message(msg), Some(info)))) @@ -624,6 +662,13 @@ enum ChunkedState { } impl Decoder { + pub fn is_eof(&self) -> bool { + match self.kind { + Kind::Length(0) | Kind::Chunked(ChunkedState::End, _) | Kind::Eof(true) => true, + _ => false, + } + } + pub fn decode(&mut self, body: &mut BytesMut) -> Poll, io::Error> { match self.kind { Kind::Length(ref mut remaining) => { @@ -815,11 +860,12 @@ mod tests { use std::{io, cmp, time}; use std::net::Shutdown; use bytes::{Bytes, BytesMut, Buf}; - use futures::Async; + use futures::{Async, Stream}; use tokio_io::{AsyncRead, AsyncWrite}; use http::{Version, Method}; use super::*; + use httpmessage::HttpMessage; use application::HttpApplication; use server::settings::WorkerSettings; use server::IoStream; @@ -1320,6 +1366,7 @@ mod tests { assert!(!req.payload().eof()); buf.feed_data("4\r\ndata\r\n4\r\nline\r\n0\r\n\r\n"); + let _ = req.payload_mut().poll(); not_ready!(reader.parse(&mut buf, &mut readbuf, &settings)); assert!(!req.payload().eof()); assert_eq!(req.payload_mut().readall().unwrap().as_ref(), b"dataline"); @@ -1344,6 +1391,7 @@ mod tests { "4\r\ndata\r\n4\r\nline\r\n0\r\n\r\n\ POST /test2 HTTP/1.1\r\n\ transfer-encoding: chunked\r\n\r\n"); + let _ = req.payload_mut().poll(); let req2 = reader_parse_ready!(reader.parse(&mut buf, &mut readbuf, &settings)); assert_eq!(*req2.method(), Method::POST); @@ -1367,6 +1415,10 @@ mod tests { assert!(req.chunked().unwrap()); assert!(!req.payload().eof()); + buf.feed_data("4\r\n1111\r\n"); + not_ready!(reader.parse(&mut buf, &mut readbuf, &settings)); + assert_eq!(req.payload_mut().readall().unwrap().as_ref(), b"1111"); + buf.feed_data("4\r\ndata\r"); not_ready!(reader.parse(&mut buf, &mut readbuf, &settings)); @@ -1384,13 +1436,18 @@ mod tests { buf.feed_data("ne\r\n0\r\n"); not_ready!(reader.parse(&mut buf, &mut readbuf, &settings)); + //trailers //buf.feed_data("test: test\r\n"); //not_ready!(reader.parse(&mut buf, &mut readbuf)); + let _ = req.payload_mut().poll(); + not_ready!(reader.parse(&mut buf, &mut readbuf, &settings)); + assert_eq!(req.payload_mut().readall().unwrap().as_ref(), b"dataline"); assert!(!req.payload().eof()); buf.feed_data("\r\n"); + let _ = req.payload_mut().poll(); not_ready!(reader.parse(&mut buf, &mut readbuf, &settings)); assert!(req.payload().eof()); } @@ -1409,6 +1466,7 @@ mod tests { assert!(!req.payload().eof()); buf.feed_data("4;test\r\ndata\r\n4\r\nline\r\n0\r\n\r\n"); // test: test\r\n\r\n") + let _ = req.payload_mut().poll(); not_ready!(reader.parse(&mut buf, &mut readbuf, &settings)); assert!(!req.payload().eof()); assert_eq!(req.payload_mut().readall().unwrap().as_ref(), b"dataline"); diff --git a/src/server/h1writer.rs b/src/server/h1writer.rs index aa9c819d7..80d02f292 100644 --- a/src/server/h1writer.rs +++ b/src/server/h1writer.rs @@ -1,3 +1,5 @@ +#![cfg_attr(feature = "cargo-clippy", allow(redundant_field_names))] + use std::{io, mem}; use bytes::BufMut; use futures::{Async, Poll}; @@ -8,11 +10,11 @@ use http::header::{HeaderValue, CONNECTION, DATE}; use helpers; use body::{Body, Binary}; use headers::ContentEncoding; -use httprequest::HttpMessage; +use httprequest::HttpInnerMessage; use httpresponse::HttpResponse; use super::{Writer, WriterState, MAX_WRITE_BUFFER_SIZE}; use super::shared::SharedBytes; -use super::encoding::PayloadEncoder; +use super::encoding::ContentEncoder; const AVERAGE_HEADER_SIZE: usize = 30; // totally scientific @@ -28,7 +30,7 @@ bitflags! { pub(crate) struct H1Writer { flags: Flags, stream: T, - encoder: PayloadEncoder, + encoder: ContentEncoder, written: u64, headers_size: u32, buffer: SharedBytes, @@ -39,11 +41,11 @@ impl H1Writer { pub fn new(stream: T, buf: SharedBytes) -> H1Writer { H1Writer { flags: Flags::empty(), - stream: stream, - encoder: PayloadEncoder::empty(buf.clone()), + encoder: ContentEncoder::empty(buf.clone()), written: 0, headers_size: 0, buffer: buf, + stream, } } @@ -96,12 +98,12 @@ impl Writer for H1Writer { } fn start(&mut self, - req: &mut HttpMessage, + req: &mut HttpInnerMessage, msg: &mut HttpResponse, encoding: ContentEncoding) -> io::Result { // prepare task - self.encoder = PayloadEncoder::new(self.buffer.clone(), req, msg, encoding); + self.encoder = ContentEncoder::for_server(self.buffer.clone(), req, msg, encoding); if msg.keep_alive().unwrap_or_else(|| req.keep_alive()) { self.flags.insert(Flags::STARTED | Flags::KEEPALIVE); } else { diff --git a/src/server/h2.rs b/src/server/h2.rs index c843fee89..02951593e 100644 --- a/src/server/h2.rs +++ b/src/server/h2.rs @@ -1,3 +1,5 @@ +#![cfg_attr(feature = "cargo-clippy", allow(redundant_field_names))] + use std::{io, cmp, mem}; use std::rc::Rc; use std::io::{Read, Write}; @@ -16,9 +18,10 @@ use tokio_core::reactor::Timeout; use pipeline::Pipeline; use error::PayloadError; -use httpcodes::HTTPNotFound; +use httpcodes::HttpNotFound; +use httpmessage::HttpMessage; use httprequest::HttpRequest; -use payload::{Payload, PayloadWriter}; +use payload::{Payload, PayloadWriter, PayloadStatus}; use super::h2writer::H2Writer; use super::encoding::PayloadType; @@ -32,7 +35,8 @@ bitflags! { } /// HTTP/2 Transport -pub(crate) struct Http2 +pub(crate) +struct Http2 where T: AsyncRead + AsyncWrite + 'static, H: 'static { flags: Flags, @@ -53,15 +57,17 @@ impl Http2 where T: AsyncRead + AsyncWrite + 'static, H: HttpHandler + 'static { - pub fn new(h: Rc>, io: T, addr: Option, buf: Bytes) -> Self + pub fn new(settings: Rc>, + io: T, + addr: Option, buf: Bytes) -> Self { Http2{ flags: Flags::empty(), - settings: h, - addr: addr, tasks: VecDeque::new(), state: State::Handshake( server::handshake(IoWrapper{unread: Some(buf), inner: io})), keepalive_timer: None, + addr, + settings, } } @@ -99,21 +105,30 @@ impl Http2 item.poll_payload(); if !item.flags.contains(EntryFlags::EOF) { - match item.task.poll_io(&mut item.stream) { - Ok(Async::Ready(ready)) => { - item.flags.insert(EntryFlags::EOF); - if ready { - item.flags.insert(EntryFlags::FINISHED); + let retry = item.payload.need_read() == PayloadStatus::Read; + loop { + match item.task.poll_io(&mut item.stream) { + Ok(Async::Ready(ready)) => { + item.flags.insert(EntryFlags::EOF); + if ready { + item.flags.insert(EntryFlags::FINISHED); + } + not_ready = false; + }, + Ok(Async::NotReady) => { + if item.payload.need_read() == PayloadStatus::Read && !retry + { + continue + } + }, + Err(err) => { + error!("Unhandled error: {}", err); + item.flags.insert(EntryFlags::EOF); + item.flags.insert(EntryFlags::ERROR); + item.stream.reset(Reason::INTERNAL_ERROR); } - not_ready = false; - }, - Ok(Async::NotReady) => (), - Err(err) => { - error!("Unhandled error: {}", err); - item.flags.insert(EntryFlags::EOF); - item.flags.insert(EntryFlags::ERROR); - item.stream.reset(Reason::INTERNAL_ERROR); } + break } } else if !item.flags.contains(EntryFlags::FINISHED) { match item.task.poll() { @@ -244,7 +259,6 @@ struct Entry { payload: PayloadType, recv: RecvStream, stream: H2Writer, - capacity: usize, flags: EntryFlags, } @@ -284,17 +298,24 @@ impl Entry { } } - Entry {task: task.unwrap_or_else(|| Pipeline::error(HTTPNotFound)), + Entry {task: task.unwrap_or_else(|| Pipeline::error(HttpNotFound)), payload: psender, - recv: recv, stream: H2Writer::new(resp, settings.get_shared_bytes()), flags: EntryFlags::empty(), - capacity: 0, + recv, } } fn poll_payload(&mut self) { if !self.flags.contains(EntryFlags::REOF) { + if self.payload.need_read() == PayloadStatus::Read { + if let Err(err) = self.recv.release_capacity().release_capacity(32_768) { + self.payload.set_error(PayloadError::Http2(err)) + } + } else if let Err(err) = self.recv.release_capacity().release_capacity(0) { + self.payload.set_error(PayloadError::Http2(err)) + } + match self.recv.poll() { Ok(Async::Ready(Some(chunk))) => { self.payload.feed_data(chunk); @@ -307,14 +328,6 @@ impl Entry { self.payload.set_error(PayloadError::Http2(err)) } } - - let capacity = self.payload.capacity(); - if self.capacity != capacity { - self.capacity = capacity; - if let Err(err) = self.recv.release_capacity().release_capacity(capacity) { - self.payload.set_error(PayloadError::Http2(err)) - } - } } } } diff --git a/src/server/h2writer.rs b/src/server/h2writer.rs index 00a981915..095cd78f2 100644 --- a/src/server/h2writer.rs +++ b/src/server/h2writer.rs @@ -1,3 +1,5 @@ +#![cfg_attr(feature = "cargo-clippy", allow(redundant_field_names))] + use std::{io, cmp}; use bytes::{Bytes, BytesMut}; use futures::{Async, Poll}; @@ -9,9 +11,9 @@ use http::header::{HeaderValue, CONNECTION, TRANSFER_ENCODING, DATE, CONTENT_LEN use helpers; use body::{Body, Binary}; use headers::ContentEncoding; -use httprequest::HttpMessage; +use httprequest::HttpInnerMessage; use httpresponse::HttpResponse; -use super::encoding::PayloadEncoder; +use super::encoding::ContentEncoder; use super::shared::SharedBytes; use super::{Writer, WriterState, MAX_WRITE_BUFFER_SIZE}; @@ -28,7 +30,7 @@ bitflags! { pub(crate) struct H2Writer { respond: SendResponse, stream: Option>, - encoder: PayloadEncoder, + encoder: ContentEncoder, flags: Flags, written: u64, buffer: SharedBytes, @@ -38,9 +40,9 @@ impl H2Writer { pub fn new(respond: SendResponse, buf: SharedBytes) -> H2Writer { H2Writer { - respond: respond, + respond, stream: None, - encoder: PayloadEncoder::empty(buf.clone()), + encoder: ContentEncoder::empty(buf.clone()), flags: Flags::empty(), written: 0, buffer: buf, @@ -109,11 +111,11 @@ impl Writer for H2Writer { self.written } - fn start(&mut self, req: &mut HttpMessage, msg: &mut HttpResponse, encoding: ContentEncoding) + fn start(&mut self, req: &mut HttpInnerMessage, msg: &mut HttpResponse, encoding: ContentEncoding) -> io::Result { // prepare response self.flags.insert(Flags::STARTED); - self.encoder = PayloadEncoder::new(self.buffer.clone(), req, msg, encoding); + self.encoder = ContentEncoder::for_server(self.buffer.clone(), req, msg, encoding); if let Body::Empty = *msg.body() { self.flags.insert(Flags::EOF); } diff --git a/src/server/mod.rs b/src/server/mod.rs index 3769e588e..9f644a1e9 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -25,7 +25,7 @@ pub use self::settings::ServerSettings; use body::Binary; use error::Error; use headers::ContentEncoding; -use httprequest::{HttpMessage, HttpRequest}; +use httprequest::{HttpInnerMessage, HttpRequest}; use httpresponse::HttpResponse; /// max buffer size 64k @@ -103,7 +103,7 @@ pub enum WriterState { pub trait Writer { fn written(&self) -> u64; - fn start(&mut self, req: &mut HttpMessage, resp: &mut HttpResponse, encoding: ContentEncoding) + fn start(&mut self, req: &mut HttpInnerMessage, resp: &mut HttpResponse, encoding: ContentEncoding) -> io::Result; fn write(&mut self, payload: Binary) -> io::Result; diff --git a/src/server/settings.rs b/src/server/settings.rs index 0ca4b4371..b0b8eb552 100644 --- a/src/server/settings.rs +++ b/src/server/settings.rs @@ -36,11 +36,7 @@ impl ServerSettings { } else { "localhost".to_owned() }; - ServerSettings { - addr: addr, - secure: secure, - host: host, - } + ServerSettings { addr, secure, host } } /// Returns the socket address of the local half of this TCP connection @@ -67,7 +63,7 @@ pub(crate) struct WorkerSettings { bytes: Rc, messages: Rc, channels: Cell, - node: Node<()>, + node: Box>, } impl WorkerSettings { @@ -79,7 +75,7 @@ impl WorkerSettings { bytes: Rc::new(SharedBytesPool::new()), messages: Rc::new(helpers::SharedMessagePool::new()), channels: Cell::new(0), - node: Node::head(), + node: Box::new(Node::head()), } } @@ -107,8 +103,8 @@ impl WorkerSettings { SharedBytes::new(self.bytes.get_bytes(), Rc::clone(&self.bytes)) } - pub fn get_http_message(&self) -> helpers::SharedHttpMessage { - helpers::SharedHttpMessage::new(self.messages.get(), Rc::clone(&self.messages)) + pub fn get_http_message(&self) -> helpers::SharedHttpInnerMessage { + helpers::SharedHttpInnerMessage::new(self.messages.get(), Rc::clone(&self.messages)) } pub fn add_channel(&self) { diff --git a/src/server/srv.rs b/src/server/srv.rs index 63f23d245..e219049ba 100644 --- a/src/server/srv.rs +++ b/src/server/srv.rs @@ -261,7 +261,7 @@ impl HttpServer /// /// HttpServer::new( /// || Application::new() - /// .resource("/", |r| r.h(httpcodes::HTTPOk))) + /// .resource("/", |r| r.h(httpcodes::HttpOk))) /// .bind("127.0.0.1:0").expect("Can not bind to 127.0.0.1:0") /// .start(); /// # actix::Arbiter::system().do_send(actix::msgs::SystemExit(0)); @@ -312,7 +312,7 @@ impl HttpServer /// fn main() { /// HttpServer::new( /// || Application::new() - /// .resource("/", |r| r.h(httpcodes::HTTPOk))) + /// .resource("/", |r| r.h(httpcodes::HttpOk))) /// .bind("127.0.0.1:0").expect("Can not bind to 127.0.0.1:0") /// .run(); /// } @@ -697,7 +697,7 @@ fn create_tcp_listener(addr: net::SocketAddr, backlog: i32) -> io::Result TcpBuilder::new_v4()?, net::SocketAddr::V6(_) => TcpBuilder::new_v6()?, }; - builder.bind(addr)?; builder.reuse_address(true)?; + builder.bind(addr)?; Ok(builder.listen(backlog)?) } diff --git a/src/server/utils.rs b/src/server/utils.rs index 79e0a11c5..bbc890e94 100644 --- a/src/server/utils.rs +++ b/src/server/utils.rs @@ -5,7 +5,7 @@ use futures::{Async, Poll}; use super::IoStream; const LW_BUFFER_SIZE: usize = 4096; -const HW_BUFFER_SIZE: usize = 16_384; +const HW_BUFFER_SIZE: usize = 32_768; pub fn read_from_io(io: &mut T, buf: &mut BytesMut) -> Poll { diff --git a/src/server/worker.rs b/src/server/worker.rs index 23e8a6c61..5257d8615 100644 --- a/src/server/worker.rs +++ b/src/server/worker.rs @@ -62,7 +62,7 @@ impl Worker { Worker { settings: Rc::new(WorkerSettings::new(h, keep_alive)), hnd: Arbiter::handle().clone(), - handler: handler, + handler, } } diff --git a/src/test.rs b/src/test.rs index faad063f2..8f5519459 100644 --- a/src/test.rs +++ b/src/test.rs @@ -14,6 +14,7 @@ use tokio_core::net::TcpListener; use tokio_core::reactor::Core; use net2::TcpBuilder; +use ws; use body::Binary; use error::Error; use handler::{Handler, Responder, ReplyItem}; @@ -25,7 +26,6 @@ use payload::Payload; use httprequest::HttpRequest; use httpresponse::HttpResponse; use server::{HttpServer, IntoHttpHandler, ServerSettings}; -use ws::{WsClient, WsClientError, WsClientReader, WsClientWriter}; use client::{ClientRequest, ClientRequestBuilder}; /// The `TestServer` type. @@ -41,7 +41,7 @@ use client::{ClientRequest, ClientRequestBuilder}; /// # use actix_web::*; /// # /// # fn my_handler(req: HttpRequest) -> HttpResponse { -/// # httpcodes::HTTPOk.into() +/// # httpcodes::HttpOk.into() /// # } /// # /// # fn main() { @@ -94,12 +94,12 @@ impl TestServer { let _ = sys.run(); }); - let (sys, addr) = rx.recv().unwrap(); + let (server_sys, addr) = rx.recv().unwrap(); TestServer { - addr: addr, + addr, thread: Some(join), system: System::new("actix-test"), - server_sys: sys, + server_sys, } } @@ -131,12 +131,12 @@ impl TestServer { let _ = sys.run(); }); - let (sys, addr) = rx.recv().unwrap(); + let (server_sys, addr) = rx.recv().unwrap(); TestServer { - addr: addr, + addr, + server_sys, thread: Some(join), system: System::new("actix-test"), - server_sys: sys, } } @@ -180,9 +180,9 @@ impl TestServer { } /// Connect to websocket server - pub fn ws(&mut self) -> Result<(WsClientReader, WsClientWriter), WsClientError> { + pub fn ws(&mut self) -> Result<(ws::ClientReader, ws::ClientWriter), ws::ClientError> { let url = self.url("/"); - self.system.run_until_complete(WsClient::new(url).connect().unwrap()) + self.system.run_until_complete(ws::Client::new(url).connect()) } /// Create `GET` request @@ -282,9 +282,9 @@ impl Iterator for TestApp { /// /// fn index(req: HttpRequest) -> HttpResponse { /// if let Some(hdr) = req.headers().get(header::CONTENT_TYPE) { -/// httpcodes::HTTPOk.into() +/// httpcodes::HttpOk.into() /// } else { -/// httpcodes::HTTPBadRequest.into() +/// httpcodes::HttpBadRequest.into() /// } /// } /// @@ -346,7 +346,7 @@ impl TestRequest { /// Start HttpRequest build process with application state pub fn with_state(state: S) -> TestRequest { TestRequest { - state: state, + state, method: Method::GET, uri: Uri::from_str("/").unwrap(), version: Version::HTTP_11, @@ -403,7 +403,7 @@ impl TestRequest { self.payload = Some(payload); self } - + /// Complete request creation and generate `HttpRequest` instance pub fn finish(self) -> HttpRequest { let TestRequest { state, method, uri, version, headers, params, cookies, payload } = self; @@ -434,7 +434,7 @@ impl TestRequest { let req = self.finish(); let resp = h.handle(req.clone()); - match resp.respond_to(req.clone_without_state()) { + match resp.respond_to(req.without_state()) { Ok(resp) => { match resp.into().into() { ReplyItem::Message(resp) => Ok(resp), @@ -461,7 +461,7 @@ impl TestRequest { let mut core = Core::new().unwrap(); match core.run(fut) { Ok(r) => { - match r.respond_to(req.clone_without_state()) { + match r.respond_to(req.without_state()) { Ok(reply) => match reply.into().into() { ReplyItem::Message(resp) => Ok(resp), _ => panic!("Nested async replies are not supported"), diff --git a/src/ws/client.rs b/src/ws/client.rs index 1d34b864b..c8fdec0ff 100644 --- a/src/ws/client.rs +++ b/src/ws/client.rs @@ -1,53 +1,70 @@ //! Http client request -#![allow(unused_imports, dead_code)] use std::{fmt, io, str}; use std::rc::Rc; -use std::time::Duration; use std::cell::UnsafeCell; use base64; use rand; +use bytes::Bytes; use cookie::Cookie; -use bytes::BytesMut; use http::{HttpTryFrom, StatusCode, Error as HttpError}; use http::header::{self, HeaderName, HeaderValue}; use sha1::Sha1; use futures::{Async, Future, Poll, Stream}; -use futures::future::{Either, err as FutErr}; -use tokio_core::net::TcpStream; +use futures::unsync::mpsc::{unbounded, UnboundedSender}; +use byteorder::{ByteOrder, NetworkEndian}; use actix::prelude::*; -use body::Binary; +use body::{Body, Binary}; use error::UrlParseError; -use server::shared::SharedBytes; +use payload::PayloadHelper; +use httpmessage::HttpMessage; -use server::{utils, IoStream}; -use client::{ClientRequest, ClientRequestBuilder, - HttpResponseParser, HttpResponseParserError, HttpClientWriter}; -use client::{Connect, Connection, ClientConnector, ClientConnectorError}; +use client::{ClientRequest, ClientRequestBuilder, ClientResponse, + ClientConnector, SendRequest, SendRequestError, + HttpResponseParserError}; -use super::Message; +use super::{Message, ProtocolError}; use super::frame::Frame; use super::proto::{CloseCode, OpCode}; -pub type WsClientFuture = - Future; + +/// Backward compatibility +#[doc(hidden)] +#[deprecated(since="0.4.2", note="please use `ws::Client` instead")] +pub type WsClient = Client; +#[doc(hidden)] +#[deprecated(since="0.4.2", note="please use `ws::ClientError` instead")] +pub type WsClientError = ClientError; +#[doc(hidden)] +#[deprecated(since="0.4.2", note="please use `ws::ClientReader` instead")] +pub type WsClientReader = ClientReader; +#[doc(hidden)] +#[deprecated(since="0.4.2", note="please use `ws::ClientWriter` instead")] +pub type WsClientWriter = ClientWriter; +#[doc(hidden)] +#[deprecated(since="0.4.2", note="please use `ws::ClientHandshake` instead")] +pub type WsClientHandshake = ClientHandshake; /// Websocket client error #[derive(Fail, Debug)] -pub enum WsClientError { +pub enum ClientError { #[fail(display="Invalid url")] InvalidUrl, #[fail(display="Invalid response status")] - InvalidResponseStatus, + InvalidResponseStatus(StatusCode), #[fail(display="Invalid upgrade header")] InvalidUpgradeHeader, #[fail(display="Invalid connection header")] - InvalidConnectionHeader, + InvalidConnectionHeader(HeaderValue), + #[fail(display="Missing CONNECTION header")] + MissingConnectionHeader, + #[fail(display="Missing SEC-WEBSOCKET-ACCEPT header")] + MissingWebSocketAcceptHeader, #[fail(display="Invalid challenge response")] - InvalidChallengeResponse, + InvalidChallengeResponse(String, HeaderValue), #[fail(display="Http parsing error")] Http(HttpError), #[fail(display="Url parsing error")] @@ -55,40 +72,48 @@ pub enum WsClientError { #[fail(display="Response parsing error")] ResponseParseError(HttpResponseParserError), #[fail(display="{}", _0)] - Connector(ClientConnectorError), + SendRequest(SendRequestError), + #[fail(display="{}", _0)] + Protocol(#[cause] ProtocolError), #[fail(display="{}", _0)] Io(io::Error), #[fail(display="Disconnected")] Disconnected, } -impl From for WsClientError { - fn from(err: HttpError) -> WsClientError { - WsClientError::Http(err) +impl From for ClientError { + fn from(err: HttpError) -> ClientError { + ClientError::Http(err) } } -impl From for WsClientError { - fn from(err: UrlParseError) -> WsClientError { - WsClientError::Url(err) +impl From for ClientError { + fn from(err: UrlParseError) -> ClientError { + ClientError::Url(err) } } -impl From for WsClientError { - fn from(err: ClientConnectorError) -> WsClientError { - WsClientError::Connector(err) +impl From for ClientError { + fn from(err: SendRequestError) -> ClientError { + ClientError::SendRequest(err) } } -impl From for WsClientError { - fn from(err: io::Error) -> WsClientError { - WsClientError::Io(err) +impl From for ClientError { + fn from(err: ProtocolError) -> ClientError { + ClientError::Protocol(err) } } -impl From for WsClientError { - fn from(err: HttpResponseParserError) -> WsClientError { - WsClientError::ResponseParseError(err) +impl From for ClientError { + fn from(err: io::Error) -> ClientError { + ClientError::Io(err) + } +} + +impl From for ClientError { + fn from(err: HttpResponseParserError) -> ClientError { + ClientError::ResponseParseError(err) } } @@ -97,38 +122,40 @@ impl From for WsClientError { /// Example of `WebSocket` client usage is available in /// [websocket example]( /// https://github.com/actix/actix-web/blob/master/examples/websocket/src/client.rs#L24) -pub struct WsClient { +pub struct Client { request: ClientRequestBuilder, - err: Option, + err: Option, http_err: Option, origin: Option, protocols: Option, conn: Addr, + max_size: usize, } -impl WsClient { +impl Client { /// Create new websocket connection - pub fn new>(uri: S) -> WsClient { - WsClient::with_connector(uri, ClientConnector::from_registry()) + pub fn new>(uri: S) -> Client { + Client::with_connector(uri, ClientConnector::from_registry()) } /// Create new websocket connection with custom `ClientConnector` - pub fn with_connector>(uri: S, conn: Addr) -> WsClient { - let mut cl = WsClient { + pub fn with_connector>(uri: S, conn: Addr) -> Client { + let mut cl = Client { request: ClientRequest::build(), err: None, http_err: None, origin: None, protocols: None, - conn: conn, + max_size: 65_536, + conn, }; cl.request.uri(uri.as_ref()); cl } /// Set supported websocket protocols - pub fn protocols(&mut self, protos: U) -> &mut Self + pub fn protocols(mut self, protos: U) -> Self where U: IntoIterator + 'static, V: AsRef { @@ -140,13 +167,13 @@ impl WsClient { } /// Set cookie for handshake request - pub fn cookie<'c>(&mut self, cookie: Cookie<'c>) -> &mut Self { + pub fn cookie(mut self, cookie: Cookie) -> Self { self.request.cookie(cookie); self } /// Set request Origin - pub fn origin(&mut self, origin: V) -> &mut Self + pub fn origin(mut self, origin: V) -> Self where HeaderValue: HttpTryFrom { match HeaderValue::try_from(origin) { @@ -156,8 +183,16 @@ impl WsClient { self } + /// Set max frame size + /// + /// By default max size is set to 64kb + pub fn max_frame_size(mut self, size: usize) -> Self { + self.max_size = size; + self + } + /// Set request header - pub fn header(&mut self, key: K, value: V) -> &mut Self + pub fn header(mut self, key: K, value: V) -> Self where HeaderName: HttpTryFrom, HeaderValue: HttpTryFrom { self.request.header(key, value); @@ -165,70 +200,66 @@ impl WsClient { } /// Connect to websocket server and do ws handshake - pub fn connect(&mut self) -> Result, WsClientError> { + pub fn connect(&mut self) -> ClientHandshake { if let Some(e) = self.err.take() { - return Err(e) + ClientHandshake::error(e) } - if let Some(e) = self.http_err.take() { - return Err(e.into()) - } - - // origin - if let Some(origin) = self.origin.take() { - self.request.set_header(header::ORIGIN, origin); - } - - self.request.upgrade(); - self.request.set_header(header::UPGRADE, "websocket"); - self.request.set_header(header::CONNECTION, "upgrade"); - self.request.set_header("SEC-WEBSOCKET-VERSION", "13"); - - if let Some(protocols) = self.protocols.take() { - self.request.set_header("SEC-WEBSOCKET-PROTOCOL", protocols.as_str()); - } - let request = self.request.finish()?; - - if request.uri().host().is_none() { - return Err(WsClientError::InvalidUrl) - } - if let Some(scheme) = request.uri().scheme_part() { - if scheme != "http" && scheme != "https" && scheme != "ws" && scheme != "wss" { - return Err(WsClientError::InvalidUrl); - } + else if let Some(e) = self.http_err.take() { + ClientHandshake::error(e.into()) } else { - return Err(WsClientError::InvalidUrl); - } + // origin + if let Some(origin) = self.origin.take() { + self.request.set_header(header::ORIGIN, origin); + } - // get connection and start handshake - Ok(Box::new( - self.conn.send(Connect(request.uri().clone())) - .map_err(|_| WsClientError::Disconnected) - .and_then(|res| match res { - Ok(stream) => Either::A(WsHandshake::new(stream, request)), - Err(err) => Either::B(FutErr(err.into())), - }) - )) + self.request.upgrade(); + self.request.set_header(header::UPGRADE, "websocket"); + self.request.set_header(header::CONNECTION, "upgrade"); + self.request.set_header(header::SEC_WEBSOCKET_VERSION, "13"); + self.request.with_connector(self.conn.clone()); + + if let Some(protocols) = self.protocols.take() { + self.request.set_header(header::SEC_WEBSOCKET_PROTOCOL, protocols.as_str()); + } + let request = match self.request.finish() { + Ok(req) => req, + Err(err) => return ClientHandshake::error(err.into()), + }; + + if request.uri().host().is_none() { + return ClientHandshake::error(ClientError::InvalidUrl) + } + if let Some(scheme) = request.uri().scheme_part() { + if scheme != "http" && scheme != "https" && scheme != "ws" && scheme != "wss" { + return ClientHandshake::error(ClientError::InvalidUrl) + } + } else { + return ClientHandshake::error(ClientError::InvalidUrl) + } + + // start handshake + ClientHandshake::new(request, self.max_size) + } } } -struct WsInner { - conn: Connection, - writer: HttpClientWriter, - parser: HttpResponseParser, - parser_buf: BytesMut, +struct Inner { + tx: UnboundedSender, + rx: PayloadHelper, closed: bool, - error_sent: bool, } -struct WsHandshake { - inner: Option, - request: ClientRequest, - sent: bool, +pub struct ClientHandshake { + request: Option, + tx: Option>, key: String, + error: Option, + max_size: usize, } -impl WsHandshake { - fn new(conn: Connection, mut request: ClientRequest) -> WsHandshake { +impl ClientHandshake { + fn new(mut request: ClientRequest, max_size: usize) -> ClientHandshake + { // Generate a random key for the `Sec-WebSocket-Key` header. // a base64-encoded (see Section 4 of [RFC4648]) value that, // when decoded, is 16 bytes in length (RFC 6455) @@ -236,155 +267,169 @@ impl WsHandshake { let key = base64::encode(&sec_key); request.headers_mut().insert( - HeaderName::try_from("SEC-WEBSOCKET-KEY").unwrap(), + header::SEC_WEBSOCKET_KEY, HeaderValue::try_from(key.as_str()).unwrap()); - let inner = WsInner { - conn: conn, - writer: HttpClientWriter::new(SharedBytes::default()), - parser: HttpResponseParser::default(), - parser_buf: BytesMut::new(), - closed: false, - error_sent: false, - }; + let (tx, rx) = unbounded(); + request.set_body(Body::Streaming( + Box::new(rx.map_err(|_| io::Error::new( + io::ErrorKind::Other, "disconnected").into())))); - WsHandshake { - key: key, - inner: Some(inner), - request: request, - sent: false, + ClientHandshake { + key, + max_size, + request: Some(request.send()), + tx: Some(tx), + error: None, + } + } + + fn error(err: ClientError) -> ClientHandshake { + ClientHandshake { + key: String::new(), + request: None, + tx: None, + error: Some(err), + max_size: 0 } } } -impl Future for WsHandshake { - type Item = (WsClientReader, WsClientWriter); - type Error = WsClientError; +impl Future for ClientHandshake { + type Item = (ClientReader, ClientWriter); + type Error = ClientError; fn poll(&mut self) -> Poll { - let mut inner = self.inner.take().unwrap(); - - if !self.sent { - self.sent = true; - inner.writer.start(&mut self.request)?; - } - if let Err(err) = inner.writer.poll_completed(&mut inner.conn, false) { - return Err(err.into()) + if let Some(err) = self.error.take() { + return Err(err) } - match inner.parser.parse(&mut inner.conn, &mut inner.parser_buf) { - Ok(Async::Ready(resp)) => { - // verify response - if resp.status() != StatusCode::SWITCHING_PROTOCOLS { - return Err(WsClientError::InvalidResponseStatus) - } - // Check for "UPGRADE" to websocket header - let has_hdr = if let Some(hdr) = resp.headers().get(header::UPGRADE) { - if let Ok(s) = hdr.to_str() { - s.to_lowercase().contains("websocket") - } else { - false - } - } else { - false - }; - if !has_hdr { - return Err(WsClientError::InvalidUpgradeHeader) - } - // Check for "CONNECTION" header - let has_hdr = if let Some(conn) = resp.headers().get(header::CONNECTION) { - if let Ok(s) = conn.to_str() { - s.to_lowercase().contains("upgrade") - } else { false } - } else { false }; - if !has_hdr { - return Err(WsClientError::InvalidConnectionHeader) - } - - let match_key = if let Some(key) = resp.headers().get( - HeaderName::try_from("SEC-WEBSOCKET-ACCEPT").unwrap()) - { - // field is constructed by concatenating /key/ - // with the string "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" (RFC 6455) - const WS_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; - let mut sha1 = Sha1::new(); - sha1.update(self.key.as_ref()); - sha1.update(WS_GUID); - key.as_bytes() == base64::encode(&sha1.digest().bytes()).as_bytes() - } else { - false - }; - if !match_key { - return Err(WsClientError::InvalidChallengeResponse) - } - - let inner = Rc::new(UnsafeCell::new(inner)); - Ok(Async::Ready( - (WsClientReader{inner: Rc::clone(&inner)}, - WsClientWriter{inner: inner}))) + let resp = match self.request.as_mut().unwrap().poll()? { + Async::Ready(response) => { + self.request.take(); + response }, - Ok(Async::NotReady) => { - self.inner = Some(inner); - Ok(Async::NotReady) - }, - Err(err) => Err(err.into()) + Async::NotReady => return Ok(Async::NotReady) + }; + + // verify response + if resp.status() != StatusCode::SWITCHING_PROTOCOLS { + return Err(ClientError::InvalidResponseStatus(resp.status())) } + // Check for "UPGRADE" to websocket header + let has_hdr = if let Some(hdr) = resp.headers().get(header::UPGRADE) { + if let Ok(s) = hdr.to_str() { + s.to_lowercase().contains("websocket") + } else { + false + } + } else { + false + }; + if !has_hdr { + trace!("Invalid upgrade header"); + return Err(ClientError::InvalidUpgradeHeader) + } + // Check for "CONNECTION" header + if let Some(conn) = resp.headers().get(header::CONNECTION) { + if let Ok(s) = conn.to_str() { + if !s.to_lowercase().contains("upgrade") { + trace!("Invalid connection header: {}", s); + return Err(ClientError::InvalidConnectionHeader(conn.clone())) + } + } else { + trace!("Invalid connection header: {:?}", conn); + return Err(ClientError::InvalidConnectionHeader(conn.clone())) + } + } else { + trace!("Missing connection header"); + return Err(ClientError::MissingConnectionHeader) + } + + if let Some(key) = resp.headers().get(header::SEC_WEBSOCKET_ACCEPT) + { + // field is constructed by concatenating /key/ + // with the string "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" (RFC 6455) + const WS_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + let mut sha1 = Sha1::new(); + sha1.update(self.key.as_ref()); + sha1.update(WS_GUID); + let encoded = base64::encode(&sha1.digest().bytes()); + if key.as_bytes() != encoded.as_bytes() { + trace!( + "Invalid challenge response: expected: {} received: {:?}", + encoded, key); + return Err(ClientError::InvalidChallengeResponse(encoded, key.clone())); + } + } else { + trace!("Missing SEC-WEBSOCKET-ACCEPT header"); + return Err(ClientError::MissingWebSocketAcceptHeader) + }; + + let inner = Inner { + tx: self.tx.take().unwrap(), + rx: PayloadHelper::new(resp), + closed: false, + }; + + let inner = Rc::new(UnsafeCell::new(inner)); + Ok(Async::Ready( + (ClientReader{inner: Rc::clone(&inner), max_size: self.max_size}, + ClientWriter{inner}))) } } -pub struct WsClientReader { - inner: Rc> +pub struct ClientReader { + inner: Rc>, + max_size: usize, } -impl fmt::Debug for WsClientReader { +impl fmt::Debug for ClientReader { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "WsClientReader()") + write!(f, "ws::ClientReader()") } } -impl WsClientReader { +impl ClientReader { #[inline] - fn as_mut(&mut self) -> &mut WsInner { + fn as_mut(&mut self) -> &mut Inner { unsafe{ &mut *self.inner.get() } } } -impl Stream for WsClientReader { +impl Stream for ClientReader { type Item = Message; - type Error = WsClientError; + type Error = ProtocolError; fn poll(&mut self) -> Poll, Self::Error> { + let max_size = self.max_size; let inner = self.as_mut(); - let mut done = false; - - match utils::read_from_io(&mut inner.conn, &mut inner.parser_buf) { - Ok(Async::Ready(0)) => { - done = true; - inner.closed = true; - }, - Ok(Async::Ready(_)) | Ok(Async::NotReady) => (), - Err(err) => - return Err(err.into()) + if inner.closed { + return Ok(Async::Ready(None)) } - // write - let _ = inner.writer.poll_completed(&mut inner.conn, false); - // read - match Frame::parse(&mut inner.parser_buf, false) { - Ok(Some(frame)) => { - // trace!("WsFrame {}", frame); - let (_finished, opcode, payload) = frame.unpack(); + match Frame::parse(&mut inner.rx, false, max_size) { + Ok(Async::Ready(Some(frame))) => { + let (finished, opcode, payload) = frame.unpack(); + + // continuation is not supported + if !finished { + inner.closed = true; + return Err(ProtocolError::NoContinuation) + } match opcode { OpCode::Continue => unimplemented!(), - OpCode::Bad => - Ok(Async::Ready(Some(Message::Error))), + OpCode::Bad => { + inner.closed = true; + Err(ProtocolError::BadOpCode) + }, OpCode::Close => { inner.closed = true; - inner.error_sent = true; - Ok(Async::Ready(Some(Message::Closed))) + let code = NetworkEndian::read_uint(payload.as_ref(), 2) as u16; + Ok(Async::Ready(Some(Message::Close(CloseCode::from(code))))) }, OpCode::Ping => Ok(Async::Ready(Some( @@ -401,53 +446,42 @@ impl Stream for WsClientReader { match String::from_utf8(tmp) { Ok(s) => Ok(Async::Ready(Some(Message::Text(s)))), - Err(_) => - Ok(Async::Ready(Some(Message::Error))), + Err(_) => { + inner.closed = true; + Err(ProtocolError::BadEncoding) + } } } } } - Ok(None) => { - if done { - Ok(Async::Ready(None)) - } else if inner.closed { - if !inner.error_sent { - inner.error_sent = true; - Ok(Async::Ready(Some(Message::Closed))) - } else { - Ok(Async::Ready(None)) - } - } else { - Ok(Async::NotReady) - } - }, - Err(err) => { + Ok(Async::Ready(None)) => Ok(Async::Ready(None)), + Ok(Async::NotReady) => Ok(Async::NotReady), + Err(e) => { inner.closed = true; - inner.error_sent = true; - Err(err.into()) + Err(e) } } } } -pub struct WsClientWriter { - inner: Rc> +pub struct ClientWriter { + inner: Rc> } -impl WsClientWriter { +impl ClientWriter { #[inline] - fn as_mut(&mut self) -> &mut WsInner { + fn as_mut(&mut self) -> &mut Inner { unsafe{ &mut *self.inner.get() } } } -impl WsClientWriter { +impl ClientWriter { /// Write payload #[inline] - fn write(&mut self, data: Binary) { + fn write(&mut self, mut data: Binary) { if !self.as_mut().closed { - let _ = self.as_mut().writer.write(data); + let _ = self.as_mut().tx.unbounded_send(data.take()); } else { warn!("Trying to write to disconnected response"); } @@ -455,7 +489,7 @@ impl WsClientWriter { /// Send text frame #[inline] - pub fn text>(&mut self, text: T) { + pub fn text>(&mut self, text: T) { self.write(Frame::message(text.into(), OpCode::Text, true, true)); } diff --git a/src/ws/context.rs b/src/ws/context.rs index b9214b749..4b0775f6a 100644 --- a/src/ws/context.rs +++ b/src/ws/context.rs @@ -18,7 +18,7 @@ use ws::frame::Frame; use ws::proto::{OpCode, CloseCode}; -/// Http actor execution context +/// `WebSockets` actor execution context pub struct WebsocketContext where A: Actor>, { inner: ContextImpl
, @@ -112,6 +112,7 @@ impl WebsocketContext where A: Actor { } let stream = self.stream.as_mut().unwrap(); stream.push(ContextFrame::Chunk(Some(data))); + self.inner.modify(); } else { warn!("Trying to write to disconnected response"); } @@ -131,7 +132,7 @@ impl WebsocketContext where A: Actor { /// Send text frame #[inline] - pub fn text>(&mut self, text: T) { + pub fn text>(&mut self, text: T) { self.write(Frame::message(text.into(), OpCode::Text, true, false)); } @@ -179,6 +180,7 @@ impl WebsocketContext where A: Actor { self.stream = Some(SmallVec::new()); } self.stream.as_mut().map(|s| s.push(frame)); + self.inner.modify(); } /// Handle of the running future diff --git a/src/ws/frame.rs b/src/ws/frame.rs index 612fe2f0a..96162b5c6 100644 --- a/src/ws/frame.rs +++ b/src/ws/frame.rs @@ -1,11 +1,15 @@ use std::{fmt, mem}; -use std::io::{Error, ErrorKind}; use std::iter::FromIterator; -use bytes::{BytesMut, BufMut}; +use bytes::{Bytes, BytesMut, BufMut}; use byteorder::{ByteOrder, BigEndian, NetworkEndian}; +use futures::{Async, Poll, Stream}; use rand; use body::Binary; +use error::{PayloadError}; +use payload::PayloadHelper; + +use ws::ProtocolError; use ws::proto::{OpCode, CloseCode}; use ws::mask::apply_mask; @@ -48,14 +52,16 @@ impl Frame { } /// Parse the input stream into a frame. - pub fn parse(buf: &mut BytesMut, server: bool) -> Result, Error> { + pub fn parse(pl: &mut PayloadHelper, server: bool, max_size: usize) + -> Poll, ProtocolError> + where S: Stream + { let mut idx = 2; - let mut size = buf.len(); - - if size < 2 { - return Ok(None) - } - size -= 2; + let buf = match pl.copy(2)? { + Async::Ready(Some(buf)) => buf, + Async::Ready(None) => return Ok(Async::Ready(None)), + Async::NotReady => return Ok(Async::NotReady), + }; let first = buf[0]; let second = buf[1]; let finished = first & 0x80 != 0; @@ -63,11 +69,9 @@ impl Frame { // check masking let masked = second & 0x80 != 0; if !masked && server { - return Err(Error::new( - ErrorKind::Other, "Received an unmasked frame from client")) + return Err(ProtocolError::UnmaskedFrame) } else if masked && !server { - return Err(Error::new( - ErrorKind::Other, "Received a masked frame from server")) + return Err(ProtocolError::MaskedFrame) } let rsv1 = first & 0x40 != 0; @@ -77,70 +81,69 @@ impl Frame { let len = second & 0x7F; let length = if len == 126 { - if size < 2 { - return Ok(None) - } + let buf = match pl.copy(4)? { + Async::Ready(Some(buf)) => buf, + Async::Ready(None) => return Ok(Async::Ready(None)), + Async::NotReady => return Ok(Async::NotReady), + }; let len = NetworkEndian::read_uint(&buf[idx..], 2) as usize; - size -= 2; idx += 2; len } else if len == 127 { - if size < 8 { - return Ok(None) - } + let buf = match pl.copy(10)? { + Async::Ready(Some(buf)) => buf, + Async::Ready(None) => return Ok(Async::Ready(None)), + Async::NotReady => return Ok(Async::NotReady), + }; let len = NetworkEndian::read_uint(&buf[idx..], 8) as usize; - size -= 8; idx += 8; len } else { len as usize }; + // check for max allowed size + if length > max_size { + return Err(ProtocolError::Overflow) + } + let mask = if server { - if size < 4 { - return Ok(None) - } else { - let mut mask_bytes = [0u8; 4]; - size -= 4; - mask_bytes.copy_from_slice(&buf[idx..idx+4]); - idx += 4; - Some(mask_bytes) - } + let buf = match pl.copy(idx + 4)? { + Async::Ready(Some(buf)) => buf, + Async::Ready(None) => return Ok(Async::Ready(None)), + Async::NotReady => return Ok(Async::NotReady), + }; + + let mut mask_bytes = [0u8; 4]; + mask_bytes.copy_from_slice(&buf[idx..idx+4]); + idx += 4; + Some(mask_bytes) } else { None }; - if size < length { - return Ok(None) - } + let mut data = match pl.readexactly(idx + length)? { + Async::Ready(Some(buf)) => buf, + Async::Ready(None) => return Ok(Async::Ready(None)), + Async::NotReady => return Ok(Async::NotReady), + }; // get body - buf.split_to(idx); - let mut data = if length > 0 { - buf.split_to(length) - } else { - BytesMut::new() - }; + data.split_to(idx); // Disallow bad opcode if let OpCode::Bad = opcode { - return Err( - Error::new( - ErrorKind::Other, - format!("Encountered invalid opcode: {}", first & 0x0F))) + return Err(ProtocolError::InvalidOpcode(first & 0x0F)) } // control frames must have length <= 125 match opcode { OpCode::Ping | OpCode::Pong if length > 125 => { - return Err( - Error::new( - ErrorKind::Other, - format!("Rejected WebSocket handshake.Received control frame with length: {}.", length))) + return Err(ProtocolError::InvalidLength(length)) } OpCode::Close if length > 125 => { debug!("Received close frame with payload length exceeding 125. Morphing to protocol close frame."); - return Ok(Some(Frame::default())) + return Ok(Async::Ready(Some(Frame::default()))) } _ => () } @@ -150,14 +153,8 @@ impl Frame { apply_mask(&mut data, mask); } - Ok(Some(Frame { - finished: finished, - rsv1: rsv1, - rsv2: rsv2, - rsv3: rsv3, - opcode: opcode, - payload: data.into(), - })) + Ok(Async::Ready(Some(Frame { + finished, rsv1, rsv2, rsv3, opcode, payload: data.into() }))) } /// Generate binary representation @@ -191,7 +188,7 @@ impl Frame { unsafe{buf.advance_mut(2)}; buf } else { - let mut buf = BytesMut::with_capacity(p_len + 8); + let mut buf = BytesMut::with_capacity(p_len + 10); buf.put_slice(&[one, two | 127]); { let buf_mut = unsafe{buf.bytes_mut()}; @@ -258,13 +255,33 @@ impl fmt::Display for Frame { #[cfg(test)] mod tests { use super::*; + use futures::stream::once; + + fn is_none(frm: Poll, ProtocolError>) -> bool { + match frm { + Ok(Async::Ready(None)) => true, + _ => false, + } + } + + fn extract(frm: Poll, ProtocolError>) -> Frame { + match frm { + Ok(Async::Ready(Some(frame))) => frame, + _ => panic!("error"), + } + } #[test] fn test_parse() { + let mut buf = PayloadHelper::new( + once(Ok(BytesMut::from(&[0b00000001u8, 0b00000001u8][..]).freeze()))); + assert!(is_none(Frame::parse(&mut buf, false, 1024))); + let mut buf = BytesMut::from(&[0b00000001u8, 0b00000001u8][..]); - assert!(Frame::parse(&mut buf, false).unwrap().is_none()); buf.extend(b"1"); - let frame = Frame::parse(&mut buf, false).unwrap().unwrap(); + let mut buf = PayloadHelper::new(once(Ok(buf.freeze()))); + + let frame = extract(Frame::parse(&mut buf, false, 1024)); println!("FRAME: {}", frame); assert!(!frame.finished); assert_eq!(frame.opcode, OpCode::Text); @@ -273,8 +290,10 @@ mod tests { #[test] fn test_parse_length0() { - let mut buf = BytesMut::from(&[0b00000001u8, 0b00000000u8][..]); - let frame = Frame::parse(&mut buf, false).unwrap().unwrap(); + let buf = BytesMut::from(&[0b00000001u8, 0b00000000u8][..]); + let mut buf = PayloadHelper::new(once(Ok(buf.freeze()))); + + let frame = extract(Frame::parse(&mut buf, false, 1024)); assert!(!frame.finished); assert_eq!(frame.opcode, OpCode::Text); assert!(frame.payload.is_empty()); @@ -282,12 +301,16 @@ mod tests { #[test] fn test_parse_length2() { + let buf = BytesMut::from(&[0b00000001u8, 126u8][..]); + let mut buf = PayloadHelper::new(once(Ok(buf.freeze()))); + assert!(is_none(Frame::parse(&mut buf, false, 1024))); + let mut buf = BytesMut::from(&[0b00000001u8, 126u8][..]); - assert!(Frame::parse(&mut buf, false).unwrap().is_none()); buf.extend(&[0u8, 4u8][..]); buf.extend(b"1234"); + let mut buf = PayloadHelper::new(once(Ok(buf.freeze()))); - let frame = Frame::parse(&mut buf, false).unwrap().unwrap(); + let frame = extract(Frame::parse(&mut buf, false, 1024)); assert!(!frame.finished); assert_eq!(frame.opcode, OpCode::Text); assert_eq!(frame.payload.as_ref(), &b"1234"[..]); @@ -295,12 +318,16 @@ mod tests { #[test] fn test_parse_length4() { + let buf = BytesMut::from(&[0b00000001u8, 127u8][..]); + let mut buf = PayloadHelper::new(once(Ok(buf.freeze()))); + assert!(is_none(Frame::parse(&mut buf, false, 1024))); + let mut buf = BytesMut::from(&[0b00000001u8, 127u8][..]); - assert!(Frame::parse(&mut buf, false).unwrap().is_none()); buf.extend(&[0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 4u8][..]); buf.extend(b"1234"); + let mut buf = PayloadHelper::new(once(Ok(buf.freeze()))); - let frame = Frame::parse(&mut buf, false).unwrap().unwrap(); + let frame = extract(Frame::parse(&mut buf, false, 1024)); assert!(!frame.finished); assert_eq!(frame.opcode, OpCode::Text); assert_eq!(frame.payload.as_ref(), &b"1234"[..]); @@ -311,10 +338,11 @@ mod tests { let mut buf = BytesMut::from(&[0b00000001u8, 0b10000001u8][..]); buf.extend(b"0001"); buf.extend(b"1"); + let mut buf = PayloadHelper::new(once(Ok(buf.freeze()))); - assert!(Frame::parse(&mut buf, false).is_err()); + assert!(Frame::parse(&mut buf, false, 1024).is_err()); - let frame = Frame::parse(&mut buf, true).unwrap().unwrap(); + let frame = extract(Frame::parse(&mut buf, true, 1024)); assert!(!frame.finished); assert_eq!(frame.opcode, OpCode::Text); assert_eq!(frame.payload, vec![1u8].into()); @@ -324,15 +352,30 @@ mod tests { fn test_parse_frame_no_mask() { let mut buf = BytesMut::from(&[0b00000001u8, 0b00000001u8][..]); buf.extend(&[1u8]); + let mut buf = PayloadHelper::new(once(Ok(buf.freeze()))); - assert!(Frame::parse(&mut buf, true).is_err()); + assert!(Frame::parse(&mut buf, true, 1024).is_err()); - let frame = Frame::parse(&mut buf, false).unwrap().unwrap(); + let frame = extract(Frame::parse(&mut buf, false, 1024)); assert!(!frame.finished); assert_eq!(frame.opcode, OpCode::Text); assert_eq!(frame.payload, vec![1u8].into()); } + #[test] + fn test_parse_frame_max_size() { + let mut buf = BytesMut::from(&[0b00000001u8, 0b00000010u8][..]); + buf.extend(&[1u8, 1u8]); + let mut buf = PayloadHelper::new(once(Ok(buf.freeze()))); + + assert!(Frame::parse(&mut buf, true, 1).is_err()); + + if let Err(ProtocolError::Overflow) = Frame::parse(&mut buf, false, 0) { + } else { + panic!("error"); + } + } + #[test] fn test_ping_frame() { let frame = Frame::message(Vec::from("data"), OpCode::Ping, true, false); diff --git a/src/ws/mod.rs b/src/ws/mod.rs index d9bf0f103..bf31189ca 100644 --- a/src/ws/mod.rs +++ b/src/ws/mod.rs @@ -23,9 +23,8 @@ //! type Context = ws::WebsocketContext; //! } //! -//! // Define Handler for ws::Message message -//! impl Handler for Ws { -//! type Result = (); +//! // Handler for ws::Message messages +//! impl StreamHandler for Ws { //! //! fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) { //! match msg { @@ -43,17 +42,20 @@ //! # .finish(); //! # } //! ``` -use bytes::BytesMut; +use bytes::Bytes; use http::{Method, StatusCode, header}; use futures::{Async, Poll, Stream}; +use byteorder::{ByteOrder, NetworkEndian}; -use actix::{Actor, AsyncContext, Handler}; +use actix::{Actor, AsyncContext, StreamHandler}; use body::Binary; -use payload::ReadAny; -use error::{Error, WsHandshakeError}; +use payload::PayloadHelper; +use error::{Error, PayloadError, ResponseError}; +use httpmessage::HttpMessage; use httprequest::HttpRequest; use httpresponse::{ConnectionType, HttpResponse, HttpResponseBuilder}; +use httpcodes::{HttpBadRequest, HttpMethodNotAllowed}; mod frame; mod proto; @@ -65,13 +67,108 @@ use self::frame::Frame; use self::proto::{hash_key, OpCode}; pub use self::proto::CloseCode; pub use self::context::WebsocketContext; -pub use self::client::{WsClient, WsClientError, WsClientReader, WsClientWriter, WsClientFuture}; +pub use self::client::{Client, ClientError, + ClientReader, ClientWriter, ClientHandshake}; -const SEC_WEBSOCKET_ACCEPT: &str = "SEC-WEBSOCKET-ACCEPT"; -const SEC_WEBSOCKET_KEY: &str = "SEC-WEBSOCKET-KEY"; -const SEC_WEBSOCKET_VERSION: &str = "SEC-WEBSOCKET-VERSION"; -// const SEC_WEBSOCKET_PROTOCOL: &'static str = "SEC-WEBSOCKET-PROTOCOL"; +#[allow(deprecated)] +pub use self::client::{WsClient, WsClientError, + WsClientReader, WsClientWriter, WsClientHandshake}; +/// Backward compatibility +#[doc(hidden)] +#[deprecated(since="0.4.2", note="please use `ws::ProtocolError` instead")] +pub type WsError = ProtocolError; +#[doc(hidden)] +#[deprecated(since="0.4.2", note="please use `ws::HandshakeError` instead")] +pub type WsHandshakeError = HandshakeError; + +/// Websocket errors +#[derive(Fail, Debug)] +pub enum ProtocolError { + /// Received an unmasked frame from client + #[fail(display="Received an unmasked frame from client")] + UnmaskedFrame, + /// Received a masked frame from server + #[fail(display="Received a masked frame from server")] + MaskedFrame, + /// Encountered invalid opcode + #[fail(display="Invalid opcode: {}", _0)] + InvalidOpcode(u8), + /// Invalid control frame length + #[fail(display="Invalid control frame length: {}", _0)] + InvalidLength(usize), + /// Bad web socket op code + #[fail(display="Bad web socket op code")] + BadOpCode, + /// A payload reached size limit. + #[fail(display="A payload reached size limit.")] + Overflow, + /// Continuation is not supproted + #[fail(display="Continuation is not supproted.")] + NoContinuation, + /// Bad utf-8 encoding + #[fail(display="Bad utf-8 encoding.")] + BadEncoding, + /// Payload error + #[fail(display="Payload error: {}", _0)] + Payload(#[cause] PayloadError), +} + +impl ResponseError for ProtocolError {} + +impl From for ProtocolError { + fn from(err: PayloadError) -> ProtocolError { + ProtocolError::Payload(err) + } +} + +/// Websocket handshake errors +#[derive(Fail, PartialEq, Debug)] +pub enum HandshakeError { + /// Only get method is allowed + #[fail(display="Method not allowed")] + GetMethodRequired, + /// Upgrade header if not set to websocket + #[fail(display="Websocket upgrade is expected")] + NoWebsocketUpgrade, + /// Connection header is not set to upgrade + #[fail(display="Connection upgrade is expected")] + NoConnectionUpgrade, + /// Websocket version header is not set + #[fail(display="Websocket version header is required")] + NoVersionHeader, + /// Unsupported websocket version + #[fail(display="Unsupported version")] + UnsupportedVersion, + /// Websocket key is not set or wrong + #[fail(display="Unknown websocket key")] + BadWebsocketKey, +} + +impl ResponseError for HandshakeError { + + fn error_response(&self) -> HttpResponse { + match *self { + HandshakeError::GetMethodRequired => { + HttpMethodNotAllowed + .build() + .header(header::ALLOW, "GET") + .finish() + .unwrap() + } + HandshakeError::NoWebsocketUpgrade => + HttpBadRequest.with_reason("No WebSocket UPGRADE header found"), + HandshakeError::NoConnectionUpgrade => + HttpBadRequest.with_reason("No CONNECTION upgrade"), + HandshakeError::NoVersionHeader => + HttpBadRequest.with_reason("Websocket version header is required"), + HandshakeError::UnsupportedVersion => + HttpBadRequest.with_reason("Unsupported version"), + HandshakeError::BadWebsocketKey => + HttpBadRequest.with_reason("Handshake error"), + } + } +} /// `WebSocket` Message #[derive(Debug, PartialEq, Message)] @@ -80,21 +177,19 @@ pub enum Message { Binary(Binary), Ping(String), Pong(String), - Close, - Closed, - Error + Close(CloseCode), } /// Do websocket handshake and start actor -pub fn start(mut req: HttpRequest, actor: A) -> Result - where A: Actor> + Handler, +pub fn start(req: HttpRequest, actor: A) -> Result + where A: Actor> + StreamHandler, S: 'static { let mut resp = handshake(&req)?; - let stream = WsStream::new(req.payload_mut().readany()); + let stream = WsStream::new(req.clone()); let mut ctx = WebsocketContext::new(req, actor); - ctx.add_message_stream(stream); + ctx.add_stream(stream); Ok(resp.body(ctx)?) } @@ -107,10 +202,10 @@ pub fn start(mut req: HttpRequest, actor: A) -> Result(req: &HttpRequest) -> Result { +pub fn handshake(req: &HttpRequest) -> Result { // WebSocket accepts only GET if *req.method() != Method::GET { - return Err(WsHandshakeError::GetMethodRequired) + return Err(HandshakeError::GetMethodRequired) } // Check for "UPGRADE" to websocket header @@ -124,35 +219,35 @@ pub fn handshake(req: &HttpRequest) -> Result(req: &HttpRequest) -> Result { + rx: PayloadHelper, closed: bool, - error_sent: bool, + max_size: usize, } -impl WsStream { - pub fn new(payload: ReadAny) -> WsStream { - WsStream { rx: payload, - buf: BytesMut::new(), +impl WsStream where S: Stream { + /// Create new websocket frames stream + pub fn new(stream: S) -> WsStream { + WsStream { rx: PayloadHelper::new(stream), closed: false, - error_sent: false } + max_size: 65_536, + } + } + + /// Set max frame size + /// + /// By default max size is set to 64kb + pub fn max_size(mut self, size: usize) -> Self { + self.max_size = size; + self } } -impl Stream for WsStream { +impl Stream for WsStream where S: Stream { type Item = Message; - type Error = (); + type Error = ProtocolError; fn poll(&mut self) -> Poll, Self::Error> { - let mut done = false; - - if !self.closed { - loop { - match self.rx.poll() { - Ok(Async::Ready(Some(chunk))) => { - self.buf.extend_from_slice(&chunk) - } - Ok(Async::Ready(None)) => { - done = true; - self.closed = true; - break; - } - Ok(Async::NotReady) => break, - Err(_) => { - self.closed = true; - break; - } - } - } + if self.closed { + return Ok(Async::Ready(None)) } - loop { - match Frame::parse(&mut self.buf, true) { - Ok(Some(frame)) => { - // trace!("WsFrame {}", frame); - let (_finished, opcode, payload) = frame.unpack(); + match Frame::parse(&mut self.rx, true, self.max_size) { + Ok(Async::Ready(Some(frame))) => { + let (finished, opcode, payload) = frame.unpack(); - match opcode { - OpCode::Continue => continue, - OpCode::Bad => - return Ok(Async::Ready(Some(Message::Error))), - OpCode::Close => { - self.closed = true; - self.error_sent = true; - return Ok(Async::Ready(Some(Message::Closed))) - }, - OpCode::Ping => - return Ok(Async::Ready(Some( - Message::Ping( - String::from_utf8_lossy(payload.as_ref()).into())))), - OpCode::Pong => - return Ok(Async::Ready(Some( - Message::Pong( - String::from_utf8_lossy(payload.as_ref()).into())))), - OpCode::Binary => - return Ok(Async::Ready(Some(Message::Binary(payload)))), - OpCode::Text => { - let tmp = Vec::from(payload.as_ref()); - match String::from_utf8(tmp) { - Ok(s) => - return Ok(Async::Ready(Some(Message::Text(s)))), - Err(_) => - return Ok(Async::Ready(Some(Message::Error))), + // continuation is not supported + if !finished { + self.closed = true; + return Err(ProtocolError::NoContinuation) + } + + match opcode { + OpCode::Continue => unimplemented!(), + OpCode::Bad => { + self.closed = true; + Err(ProtocolError::BadOpCode) + } + OpCode::Close => { + self.closed = true; + let code = NetworkEndian::read_uint(payload.as_ref(), 2) as u16; + Ok(Async::Ready( + Some(Message::Close(CloseCode::from(code))))) + }, + OpCode::Ping => + Ok(Async::Ready(Some( + Message::Ping( + String::from_utf8_lossy(payload.as_ref()).into())))), + OpCode::Pong => + Ok(Async::Ready(Some( + Message::Pong(String::from_utf8_lossy(payload.as_ref()).into())))), + OpCode::Binary => + Ok(Async::Ready(Some(Message::Binary(payload)))), + OpCode::Text => { + let tmp = Vec::from(payload.as_ref()); + match String::from_utf8(tmp) { + Ok(s) => + Ok(Async::Ready(Some(Message::Text(s)))), + Err(_) => { + self.closed = true; + Err(ProtocolError::BadEncoding) } } } } - Ok(None) => { - if done { - return Ok(Async::Ready(None)) - } else if self.closed { - if !self.error_sent { - self.error_sent = true; - return Ok(Async::Ready(Some(Message::Closed))) - } else { - return Ok(Async::Ready(None)) - } - } else { - return Ok(Async::NotReady) - } - }, - Err(_) => { - self.closed = true; - self.error_sent = true; - return Ok(Async::Ready(Some(Message::Error))); - } + } + Ok(Async::Ready(None)) => Ok(Async::Ready(None)), + Ok(Async::NotReady) => Ok(Async::NotReady), + Err(e) => { + self.closed = true; + Err(e) } } } @@ -278,25 +357,25 @@ mod tests { fn test_handshake() { let req = HttpRequest::new(Method::POST, Uri::from_str("/").unwrap(), Version::HTTP_11, HeaderMap::new(), None); - assert_eq!(WsHandshakeError::GetMethodRequired, handshake(&req).err().unwrap()); + assert_eq!(HandshakeError::GetMethodRequired, handshake(&req).err().unwrap()); let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(), Version::HTTP_11, HeaderMap::new(), None); - assert_eq!(WsHandshakeError::NoWebsocketUpgrade, handshake(&req).err().unwrap()); + assert_eq!(HandshakeError::NoWebsocketUpgrade, handshake(&req).err().unwrap()); let mut headers = HeaderMap::new(); headers.insert(header::UPGRADE, header::HeaderValue::from_static("test")); let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(), Version::HTTP_11, headers, None); - assert_eq!(WsHandshakeError::NoWebsocketUpgrade, handshake(&req).err().unwrap()); + assert_eq!(HandshakeError::NoWebsocketUpgrade, handshake(&req).err().unwrap()); let mut headers = HeaderMap::new(); headers.insert(header::UPGRADE, header::HeaderValue::from_static("websocket")); let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(), Version::HTTP_11, headers, None); - assert_eq!(WsHandshakeError::NoConnectionUpgrade, handshake(&req).err().unwrap()); + assert_eq!(HandshakeError::NoConnectionUpgrade, handshake(&req).err().unwrap()); let mut headers = HeaderMap::new(); headers.insert(header::UPGRADE, @@ -305,42 +384,58 @@ mod tests { header::HeaderValue::from_static("upgrade")); let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(), Version::HTTP_11, headers, None); - assert_eq!(WsHandshakeError::NoVersionHeader, handshake(&req).err().unwrap()); + assert_eq!(HandshakeError::NoVersionHeader, handshake(&req).err().unwrap()); let mut headers = HeaderMap::new(); headers.insert(header::UPGRADE, header::HeaderValue::from_static("websocket")); headers.insert(header::CONNECTION, header::HeaderValue::from_static("upgrade")); - headers.insert(SEC_WEBSOCKET_VERSION, + headers.insert(header::SEC_WEBSOCKET_VERSION, header::HeaderValue::from_static("5")); let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(), Version::HTTP_11, headers, None); - assert_eq!(WsHandshakeError::UnsupportedVersion, handshake(&req).err().unwrap()); + assert_eq!(HandshakeError::UnsupportedVersion, handshake(&req).err().unwrap()); let mut headers = HeaderMap::new(); headers.insert(header::UPGRADE, header::HeaderValue::from_static("websocket")); headers.insert(header::CONNECTION, header::HeaderValue::from_static("upgrade")); - headers.insert(SEC_WEBSOCKET_VERSION, + headers.insert(header::SEC_WEBSOCKET_VERSION, header::HeaderValue::from_static("13")); let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(), Version::HTTP_11, headers, None); - assert_eq!(WsHandshakeError::BadWebsocketKey, handshake(&req).err().unwrap()); + assert_eq!(HandshakeError::BadWebsocketKey, handshake(&req).err().unwrap()); let mut headers = HeaderMap::new(); headers.insert(header::UPGRADE, header::HeaderValue::from_static("websocket")); headers.insert(header::CONNECTION, header::HeaderValue::from_static("upgrade")); - headers.insert(SEC_WEBSOCKET_VERSION, + headers.insert(header::SEC_WEBSOCKET_VERSION, header::HeaderValue::from_static("13")); - headers.insert(SEC_WEBSOCKET_KEY, + headers.insert(header::SEC_WEBSOCKET_KEY, header::HeaderValue::from_static("13")); let req = HttpRequest::new(Method::GET, Uri::from_str("/").unwrap(), Version::HTTP_11, headers, None); assert_eq!(StatusCode::SWITCHING_PROTOCOLS, handshake(&req).unwrap().finish().unwrap().status()); } + + #[test] + fn test_wserror_http_response() { + let resp: HttpResponse = HandshakeError::GetMethodRequired.error_response(); + assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); + let resp: HttpResponse = HandshakeError::NoWebsocketUpgrade.error_response(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + let resp: HttpResponse = HandshakeError::NoConnectionUpgrade.error_response(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + let resp: HttpResponse = HandshakeError::NoVersionHeader.error_response(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + let resp: HttpResponse = HandshakeError::UnsupportedVersion.error_response(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + let resp: HttpResponse = HandshakeError::BadWebsocketKey.error_response(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + } } diff --git a/tests/test_client.rs b/tests/test_client.rs index 02a18f40b..aaa3fa786 100644 --- a/tests/test_client.rs +++ b/tests/test_client.rs @@ -2,8 +2,14 @@ extern crate actix; extern crate actix_web; extern crate bytes; extern crate futures; +extern crate flate2; + +use std::io::Read; use bytes::Bytes; +use futures::Future; +use futures::stream::once; +use flate2::read::GzDecoder; use actix_web::*; @@ -57,3 +63,171 @@ fn test_simple() { let bytes = srv.execute(response.body()).unwrap(); assert_eq!(bytes, Bytes::from_static(STR.as_ref())); } + +#[test] +fn test_no_decompress() { + let mut srv = test::TestServer::new( + |app| app.handler(|_| httpcodes::HTTPOk.build().body(STR))); + + let request = srv.get().disable_decompress().finish().unwrap(); + let response = srv.execute(request.send()).unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = srv.execute(response.body()).unwrap(); + + let mut e = GzDecoder::new(&bytes[..]); + let mut dec = Vec::new(); + e.read_to_end(&mut dec).unwrap(); + assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); + + // POST + let request = srv.post().disable_decompress().finish().unwrap(); + let response = srv.execute(request.send()).unwrap(); + + let bytes = srv.execute(response.body()).unwrap(); + let mut e = GzDecoder::new(&bytes[..]); + let mut dec = Vec::new(); + e.read_to_end(&mut dec).unwrap(); + assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); +} + +#[test] +fn test_client_gzip_encoding() { + let mut srv = test::TestServer::new(|app| app.handler(|req: HttpRequest| { + req.body() + .and_then(|bytes: Bytes| { + Ok(httpcodes::HTTPOk + .build() + .content_encoding(headers::ContentEncoding::Deflate) + .body(bytes)) + }).responder()} + )); + + // client request + let request = srv.post() + .content_encoding(headers::ContentEncoding::Gzip) + .body(STR).unwrap(); + let response = srv.execute(request.send()).unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = srv.execute(response.body()).unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); +} + +#[test] +fn test_client_gzip_encoding_large() { + let data = STR.repeat(10); + + let mut srv = test::TestServer::new(|app| app.handler(|req: HttpRequest| { + req.body() + .and_then(|bytes: Bytes| { + Ok(httpcodes::HTTPOk + .build() + .content_encoding(headers::ContentEncoding::Deflate) + .body(bytes)) + }).responder()} + )); + + // client request + let request = srv.post() + .content_encoding(headers::ContentEncoding::Gzip) + .body(data.clone()).unwrap(); + let response = srv.execute(request.send()).unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = srv.execute(response.body()).unwrap(); + assert_eq!(bytes, Bytes::from(data)); +} + +#[test] +fn test_client_brotli_encoding() { + let mut srv = test::TestServer::new(|app| app.handler(|req: HttpRequest| { + req.body() + .and_then(|bytes: Bytes| { + Ok(httpcodes::HTTPOk + .build() + .content_encoding(headers::ContentEncoding::Deflate) + .body(bytes)) + }).responder()} + )); + + // client request + let request = srv.client(Method::POST, "/") + .content_encoding(headers::ContentEncoding::Br) + .body(STR).unwrap(); + let response = srv.execute(request.send()).unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = srv.execute(response.body()).unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); +} + +#[test] +fn test_client_deflate_encoding() { + let mut srv = test::TestServer::new(|app| app.handler(|req: HttpRequest| { + req.body() + .and_then(|bytes: Bytes| { + Ok(httpcodes::HTTPOk + .build() + .content_encoding(headers::ContentEncoding::Br) + .body(bytes)) + }).responder()} + )); + + // client request + let request = srv.post() + .content_encoding(headers::ContentEncoding::Deflate) + .body(STR).unwrap(); + let response = srv.execute(request.send()).unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = srv.execute(response.body()).unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); +} + +#[test] +fn test_client_streaming_explicit() { + let mut srv = test::TestServer::new( + |app| app.handler( + |req: HttpRequest| req.body() + .map_err(Error::from) + .and_then(|body| { + Ok(httpcodes::HTTPOk.build() + .chunked() + .content_encoding(headers::ContentEncoding::Identity) + .body(body)?)}) + .responder())); + + let body = once(Ok(Bytes::from_static(STR.as_ref()))); + + let request = srv.get().body(Body::Streaming(Box::new(body))).unwrap(); + let response = srv.execute(request.send()).unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = srv.execute(response.body()).unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); +} + +#[test] +fn test_body_streaming_implicit() { + let mut srv = test::TestServer::new( + |app| app.handler(|_| { + let body = once(Ok(Bytes::from_static(STR.as_ref()))); + httpcodes::HTTPOk.build() + .content_encoding(headers::ContentEncoding::Gzip) + .body(Body::Streaming(Box::new(body)))})); + + let request = srv.get().finish().unwrap(); + let response = srv.execute(request.send()).unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = srv.execute(response.body()).unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); +} diff --git a/tests/test_server.rs b/tests/test_server.rs index 2cbeba8fc..92a876b5c 100644 --- a/tests/test_server.rs +++ b/tests/test_server.rs @@ -93,6 +93,39 @@ fn test_start() { } } +#[test] +#[cfg(unix)] +fn test_shutdown() { + let _ = test::TestServer::unused_addr(); + let (tx, rx) = mpsc::channel(); + + thread::spawn(move || { + let sys = System::new("test"); + let srv = HttpServer::new( + || vec![Application::new() + .resource("/", |r| r.method(Method::GET).h(httpcodes::HTTPOk))]); + + let srv = srv.bind("127.0.0.1:0").unwrap(); + let addr = srv.addrs()[0]; + let srv_addr = srv.shutdown_timeout(1).start(); + let _ = tx.send((addr, srv_addr)); + sys.run(); + }); + let (addr, srv_addr) = rx.recv().unwrap(); + + let mut sys = System::new("test-server"); + + { + let req = client::ClientRequest::get(format!("http://{}/", addr).as_str()).finish().unwrap(); + let response = sys.run_until_complete(req.send()).unwrap(); + srv_addr.do_send(server::StopServer{graceful: true}); + assert!(response.status().is_success()); + } + + thread::sleep(time::Duration::from_millis(1000)); + assert!(net::TcpStream::connect(addr).is_err()); +} + #[test] fn test_simple() { let mut srv = test::TestServer::new(|app| app.handler(httpcodes::HTTPOk)); @@ -101,6 +134,44 @@ fn test_simple() { assert!(response.status().is_success()); } +#[test] +fn test_headers() { + let data = STR.repeat(10); + let srv_data = Arc::new(data.clone()); + let mut srv = test::TestServer::new( + move |app| { + let data = srv_data.clone(); + app.handler(move |_| { + let mut builder = httpcodes::HTTPOk.build(); + for idx in 0..90 { + builder.header( + format!("X-TEST-{}", idx).as_str(), + "TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST "); + } + builder.body(data.as_ref())}) + }); + + let request = srv.get().finish().unwrap(); + let response = srv.execute(request.send()).unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = srv.execute(response.body()).unwrap(); + assert_eq!(bytes, Bytes::from(data)); +} + #[test] fn test_body() { let mut srv = test::TestServer::new( @@ -123,7 +194,7 @@ fn test_body_gzip() { .content_encoding(headers::ContentEncoding::Gzip) .body(STR))); - let request = srv.get().finish().unwrap(); + let request = srv.get().disable_decompress().finish().unwrap(); let response = srv.execute(request.send()).unwrap(); assert!(response.status().is_success()); @@ -138,7 +209,34 @@ fn test_body_gzip() { } #[test] -fn test_body_streaming_implicit() { +fn test_body_gzip_large() { + let data = STR.repeat(10); + let srv_data = Arc::new(data.clone()); + + let mut srv = test::TestServer::new( + move |app| { + let data = srv_data.clone(); + app.handler( + move |_| httpcodes::HTTPOk.build() + .content_encoding(headers::ContentEncoding::Gzip) + .body(data.as_ref()))}); + + let request = srv.get().disable_decompress().finish().unwrap(); + let response = srv.execute(request.send()).unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = srv.execute(response.body()).unwrap(); + + // decode + let mut e = GzDecoder::new(&bytes[..]); + let mut dec = Vec::new(); + e.read_to_end(&mut dec).unwrap(); + assert_eq!(Bytes::from(dec), Bytes::from(data)); +} + +#[test] +fn test_body_chunked_implicit() { let mut srv = test::TestServer::new( |app| app.handler(|_| { let body = once(Ok(Bytes::from_static(STR.as_ref()))); @@ -146,7 +244,7 @@ fn test_body_streaming_implicit() { .content_encoding(headers::ContentEncoding::Gzip) .body(Body::Streaming(Box::new(body)))})); - let request = srv.get().finish().unwrap(); + let request = srv.get().disable_decompress().finish().unwrap(); let response = srv.execute(request.send()).unwrap(); assert!(response.status().is_success()); @@ -169,7 +267,7 @@ fn test_body_br_streaming() { .content_encoding(headers::ContentEncoding::Br) .body(Body::Streaming(Box::new(body)))})); - let request = srv.get().finish().unwrap(); + let request = srv.get().disable_decompress().finish().unwrap(); let response = srv.execute(request.send()).unwrap(); assert!(response.status().is_success()); @@ -252,6 +350,7 @@ fn test_body_length() { let body = once(Ok(Bytes::from_static(STR.as_ref()))); httpcodes::HTTPOk.build() .content_length(STR.len() as u64) + .content_encoding(headers::ContentEncoding::Identity) .body(Body::Streaming(Box::new(body)))})); let request = srv.get().finish().unwrap(); @@ -264,7 +363,7 @@ fn test_body_length() { } #[test] -fn test_body_streaming_explicit() { +fn test_body_chunked_explicit() { let mut srv = test::TestServer::new( |app| app.handler(|_| { let body = once(Ok(Bytes::from_static(STR.as_ref()))); @@ -273,7 +372,7 @@ fn test_body_streaming_explicit() { .content_encoding(headers::ContentEncoding::Gzip) .body(Body::Streaming(Box::new(body)))})); - let request = srv.get().finish().unwrap(); + let request = srv.get().disable_decompress().finish().unwrap(); let response = srv.execute(request.send()).unwrap(); assert!(response.status().is_success()); @@ -297,7 +396,7 @@ fn test_body_deflate() { .body(STR))); // client request - let request = srv.get().finish().unwrap(); + let request = srv.get().disable_decompress().finish().unwrap(); let response = srv.execute(request.send()).unwrap(); assert!(response.status().is_success()); @@ -321,7 +420,7 @@ fn test_body_brotli() { .body(STR))); // client request - let request = srv.get().finish().unwrap(); + let request = srv.get().disable_decompress().finish().unwrap(); let response = srv.execute(request.send()).unwrap(); assert!(response.status().is_success()); @@ -364,31 +463,32 @@ fn test_gzip_encoding() { } #[test] -fn test_client_gzip_encoding() { +fn test_gzip_encoding_large() { + let data = STR.repeat(10); let mut srv = test::TestServer::new(|app| app.handler(|req: HttpRequest| { req.body() .and_then(|bytes: Bytes| { Ok(httpcodes::HTTPOk .build() - .content_encoding(headers::ContentEncoding::Deflate) + .content_encoding(headers::ContentEncoding::Identity) .body(bytes)) }).responder()} )); // client request + let mut e = GzEncoder::new(Vec::new(), Compression::default()); + e.write_all(data.as_ref()).unwrap(); + let enc = e.finish().unwrap(); + let request = srv.post() - .content_encoding(headers::ContentEncoding::Gzip) - .body(STR).unwrap(); + .header(header::CONTENT_ENCODING, "gzip") + .body(enc.clone()).unwrap(); let response = srv.execute(request.send()).unwrap(); assert!(response.status().is_success()); // read response let bytes = srv.execute(response.body()).unwrap(); - - let mut e = DeflateDecoder::new(Vec::new()); - e.write_all(bytes.as_ref()).unwrap(); - let dec = e.finish().unwrap(); - assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); + assert_eq!(bytes, Bytes::from(data)); } #[test] @@ -420,32 +520,32 @@ fn test_deflate_encoding() { } #[test] -fn test_client_deflate_encoding() { +fn test_deflate_encoding_large() { + let data = STR.repeat(10); let mut srv = test::TestServer::new(|app| app.handler(|req: HttpRequest| { req.body() .and_then(|bytes: Bytes| { Ok(httpcodes::HTTPOk .build() - .content_encoding(headers::ContentEncoding::Br) + .content_encoding(headers::ContentEncoding::Identity) .body(bytes)) }).responder()} )); + let mut e = DeflateEncoder::new(Vec::new(), Compression::default()); + e.write_all(data.as_ref()).unwrap(); + let enc = e.finish().unwrap(); + // client request let request = srv.post() - .content_encoding(headers::ContentEncoding::Deflate) - .body(STR).unwrap(); + .header(header::CONTENT_ENCODING, "deflate") + .body(enc).unwrap(); let response = srv.execute(request.send()).unwrap(); assert!(response.status().is_success()); // read response let bytes = srv.execute(response.body()).unwrap(); - - // decode brotli - let mut e = BrotliDecoder::new(Vec::with_capacity(2048)); - e.write_all(bytes.as_ref()).unwrap(); - let dec = e.finish().unwrap(); - assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); + assert_eq!(bytes, Bytes::from(data)); } #[test] @@ -477,32 +577,32 @@ fn test_brotli_encoding() { } #[test] -fn test_client_brotli_encoding() { +fn test_brotli_encoding_large() { + let data = STR.repeat(10); let mut srv = test::TestServer::new(|app| app.handler(|req: HttpRequest| { req.body() .and_then(|bytes: Bytes| { Ok(httpcodes::HTTPOk .build() - .content_encoding(headers::ContentEncoding::Deflate) + .content_encoding(headers::ContentEncoding::Identity) .body(bytes)) }).responder()} )); + let mut e = BrotliEncoder::new(Vec::new(), 5); + e.write_all(data.as_ref()).unwrap(); + let enc = e.finish().unwrap(); + // client request - let request = srv.client(Method::POST, "/") - .content_encoding(headers::ContentEncoding::Br) - .body(STR).unwrap(); + let request = srv.post() + .header(header::CONTENT_ENCODING, "br") + .body(enc).unwrap(); let response = srv.execute(request.send()).unwrap(); assert!(response.status().is_success()); // read response let bytes = srv.execute(response.body()).unwrap(); - - // decode brotli - let mut e = DeflateDecoder::new(Vec::with_capacity(2048)); - e.write_all(bytes.as_ref()).unwrap(); - let dec = e.finish().unwrap(); - assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); + assert_eq!(bytes, Bytes::from(data)); } #[test] @@ -545,30 +645,6 @@ fn test_h2() { // assert_eq!(_res.unwrap(), Bytes::from_static(STR.as_ref())); } -#[test] -fn test_client_streaming_explicit() { - let mut srv = test::TestServer::new( - |app| app.handler( - |req: HttpRequest| req.body() - .map_err(Error::from) - .and_then(|body| { - Ok(httpcodes::HTTPOk.build() - .chunked() - .content_encoding(headers::ContentEncoding::Identity) - .body(body)?)}) - .responder())); - - let body = once(Ok(Bytes::from_static(STR.as_ref()))); - - let request = srv.get().body(Body::Streaming(Box::new(body))).unwrap(); - let response = srv.execute(request.send()).unwrap(); - assert!(response.status().is_success()); - - // read response - let bytes = srv.execute(response.body()).unwrap(); - assert_eq!(bytes, Bytes::from_static(STR.as_ref())); -} - #[test] fn test_application() { let mut srv = test::TestServer::with_factory( diff --git a/tests/test_ws.rs b/tests/test_ws.rs index ac7119914..edda3f64b 100644 --- a/tests/test_ws.rs +++ b/tests/test_ws.rs @@ -16,14 +16,14 @@ impl Actor for Ws { type Context = ws::WebsocketContext; } -impl Handler for Ws { - type Result = (); +impl StreamHandler for Ws { fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) { match msg { ws::Message::Ping(msg) => ctx.pong(&msg), ws::Message::Text(text) => ctx.text(text), ws::Message::Binary(bin) => ctx.binary(bin), + ws::Message::Close(reason) => ctx.close(reason, ""), _ => (), } } @@ -49,5 +49,5 @@ fn test_simple() { writer.close(ws::CloseCode::Normal, ""); let (item, _) = srv.execute(reader.into_future()).unwrap(); - assert!(item.is_none()); + assert_eq!(item, Some(ws::Message::Close(ws::CloseCode::Normal))); } diff --git a/tools/wsload/src/wsclient.rs b/tools/wsload/src/wsclient.rs index e6438c634..2d8db7fb7 100644 --- a/tools/wsload/src/wsclient.rs +++ b/tools/wsload/src/wsclient.rs @@ -19,7 +19,7 @@ use futures::Future; use rand::{thread_rng, Rng}; use actix::prelude::*; -use actix_web::ws::{Message, WsClientError, WsClient, WsClientWriter}; +use actix_web::ws; fn main() { @@ -71,21 +71,21 @@ fn main() { let perf = perf_counters.clone(); let addr = Arbiter::new(format!("test {}", t)); - addr.send(actix::msgs::Execute::new(move || -> Result<(), ()> { + addr.do_send(actix::msgs::Execute::new(move || -> Result<(), ()> { let mut reps = report; for _ in 0..concurrency { let pl2 = pl.clone(); let perf2 = perf.clone(); Arbiter::handle().spawn( - WsClient::new(&ws).connect().unwrap() + ws::Client::new(&ws).connect() .map_err(|e| { println!("Error: {}", e); - Arbiter::system().send(actix::msgs::SystemExit(0)); + Arbiter::system().do_send(actix::msgs::SystemExit(0)); () }) .map(move |(reader, writer)| { - let addr: SyncAddress<_> = ChatClient::create(move |ctx| { + let addr: Addr = ChatClient::create(move |ctx| { ChatClient::add_stream(reader, ctx); ChatClient{conn: writer, payload: pl2, @@ -114,7 +114,7 @@ fn parse_u64_default(input: Option<&str>, default: u64) -> u64 { } struct ChatClient{ - conn: WsClientWriter, + conn: ws::ClientWriter, payload: Arc, ts: u64, bin: bool, @@ -133,9 +133,9 @@ impl Actor for ChatClient { } } - fn stopping(&mut self, _: &mut Context) -> bool { - Arbiter::system().send(actix::msgs::SystemExit(0)); - true + fn stopping(&mut self, _: &mut Context) -> Running { + Arbiter::system().do_send(actix::msgs::SystemExit(0)); + Running::Stop } } @@ -171,15 +171,15 @@ impl ChatClient { } /// Handle server websocket messages -impl StreamHandler for ChatClient { +impl StreamHandler for ChatClient { fn finished(&mut self, ctx: &mut Context) { ctx.stop() } - fn handle(&mut self, msg: Message, ctx: &mut Context) { + fn handle(&mut self, msg: ws::Message, ctx: &mut Context) { match msg { - Message::Text(txt) => { + ws::Message::Text(txt) => { if txt == self.payload.as_ref().as_str() { self.perf_counters.register_request(); self.perf_counters.register_latency(time::precise_time_ns() - self.ts);