diff --git a/actix-http/.appveyor.yml b/actix-http/.appveyor.yml new file mode 100644 index 000000000..780fdd6b5 --- /dev/null +++ b/actix-http/.appveyor.yml @@ -0,0 +1,41 @@ +environment: + global: + PROJECT_NAME: actix-http + matrix: + # Stable channel + - TARGET: i686-pc-windows-msvc + CHANNEL: stable + - TARGET: x86_64-pc-windows-gnu + CHANNEL: stable + - TARGET: x86_64-pc-windows-msvc + CHANNEL: stable + # Nightly channel + - TARGET: i686-pc-windows-msvc + CHANNEL: nightly + - TARGET: x86_64-pc-windows-gnu + CHANNEL: nightly + - TARGET: x86_64-pc-windows-msvc + CHANNEL: nightly + +# Install Rust and Cargo +# (Based on from https://github.com/rust-lang/libc/blob/master/appveyor.yml) +install: + - ps: >- + If ($Env:TARGET -eq 'x86_64-pc-windows-gnu') { + $Env:PATH += ';C:\msys64\mingw64\bin' + } ElseIf ($Env:TARGET -eq 'i686-pc-windows-gnu') { + $Env:PATH += ';C:\MinGW\bin' + } + - curl -sSf -o rustup-init.exe https://win.rustup.rs + - rustup-init.exe --default-host %TARGET% --default-toolchain %CHANNEL% -y + - set PATH=%PATH%;C:\Users\appveyor\.cargo\bin + - rustc -Vv + - cargo -V + +# 'cargo test' takes care of building for us, so disable Appveyor's build stage. +build: false + +# Equivalent to Travis' `script` phase +test_script: + - cargo clean + - cargo test diff --git a/actix-http/.gitignore b/actix-http/.gitignore new file mode 100644 index 000000000..42d0755dd --- /dev/null +++ b/actix-http/.gitignore @@ -0,0 +1,14 @@ +Cargo.lock +target/ +guide/build/ +/gh-pages + +*.so +*.out +*.pyc +*.pid +*.sock +*~ + +# These are backup files generated by rustfmt +**/*.rs.bk diff --git a/actix-http/.travis.yml b/actix-http/.travis.yml new file mode 100644 index 000000000..02fbd42c7 --- /dev/null +++ b/actix-http/.travis.yml @@ -0,0 +1,52 @@ +language: rust +sudo: required +dist: trusty + +cache: + cargo: true + apt: true + +matrix: + include: + - rust: stable + - rust: beta + - rust: nightly-2019-03-02 + allow_failures: + - rust: nightly-2019-03-02 + +env: + global: + - RUSTFLAGS="-C link-dead-code" + - OPENSSL_VERSION=openssl-1.0.2 + +before_install: + - sudo add-apt-repository -y ppa:0k53d-karl-f830m/openssl + - sudo apt-get update -qq + - sudo apt-get install -y openssl libssl-dev libelf-dev libdw-dev cmake gcc binutils-dev libiberty-dev + +before_cache: | + if [[ "$TRAVIS_RUST_VERSION" == "nightly-2019-03-02" ]]; then + RUSTFLAGS="--cfg procmacro2_semver_exempt" cargo install cargo-tarpaulin + fi + +script: +- cargo clean +- cargo build --all-features +- cargo test --all-features + +# Upload docs +after_success: + - | + if [[ "$TRAVIS_OS_NAME" == "linux" && "$TRAVIS_PULL_REQUEST" = "false" && "$TRAVIS_BRANCH" == "master" && "$TRAVIS_RUST_VERSION" == "stable" ]]; then + cargo doc --no-deps && + echo "" > target/doc/index.html && + git clone https://github.com/davisp/ghp-import.git && + ./ghp-import/ghp_import.py -n -p -f -m "Documentation upload" -r https://"$GH_TOKEN"@github.com/"$TRAVIS_REPO_SLUG.git" target/doc && + echo "Uploaded documentation" + fi + - | + if [[ "$TRAVIS_RUST_VERSION" == "nightly-2019-03-02" ]]; then + taskset -c 0 cargo tarpaulin --features="ssl" --out Xml + bash <(curl -s https://codecov.io/bash) + echo "Uploaded code coverage" + fi diff --git a/actix-http/CHANGES.md b/actix-http/CHANGES.md new file mode 100644 index 000000000..74fa4a22e --- /dev/null +++ b/actix-http/CHANGES.md @@ -0,0 +1,5 @@ +# Changes + +## [0.1.0] - 2019-01-x + +* Initial impl diff --git a/actix-http/CODE_OF_CONDUCT.md b/actix-http/CODE_OF_CONDUCT.md new file mode 100644 index 000000000..599b28c0d --- /dev/null +++ b/actix-http/CODE_OF_CONDUCT.md @@ -0,0 +1,46 @@ +# Contributor Covenant Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to making participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, gender identity and expression, level of experience, nationality, personal appearance, race, religion, or sexual identity and orientation. + +## Our Standards + +Examples of behavior that contributes to creating a positive environment include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies both within project spaces and in public spaces when an individual is representing the project or its community. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at fafhrd91@gmail.com. The project team will review and investigate all complaints, and will respond in a way that it deems appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project's leadership. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, available at [http://contributor-covenant.org/version/1/4][version] + +[homepage]: http://contributor-covenant.org +[version]: http://contributor-covenant.org/version/1/4/ diff --git a/actix-http/Cargo.toml b/actix-http/Cargo.toml new file mode 100644 index 000000000..da11a5853 --- /dev/null +++ b/actix-http/Cargo.toml @@ -0,0 +1,94 @@ +[package] +name = "actix-http" +version = "0.1.0" +authors = ["Nikolay Kim "] +description = "Actix http" +readme = "README.md" +keywords = ["http", "web", "framework", "async", "futures"] +homepage = "https://actix.rs" +repository = "https://github.com/actix/actix-http.git" +documentation = "https://docs.rs/actix-http/" +categories = ["network-programming", "asynchronous", + "web-programming::http-server", + "web-programming::websocket"] +license = "MIT/Apache-2.0" +exclude = [".gitignore", ".travis.yml", ".cargo/config", "appveyor.yml"] +edition = "2018" + +[package.metadata.docs.rs] +features = ["ssl", "fail", "cookie"] + +[badges] +travis-ci = { repository = "actix/actix-http", branch = "master" } +# appveyor = { repository = "fafhrd91/actix-http-b1qsn" } +codecov = { repository = "actix/actix-http", branch = "master", service = "github" } + +[lib] +name = "actix_http" +path = "src/lib.rs" + +[features] +default = [] + +# openssl +ssl = ["openssl", "actix-connect/ssl"] + +# cookies integration +cookies = ["cookie"] + +# failure integration. actix does not use failure anymore +fail = ["failure"] + +[dependencies] +actix-service = "0.3.4" +actix-codec = "0.1.1" +actix-connect = "0.1.0" +actix-utils = "0.3.4" +actix-server-config = "0.1.0" + +base64 = "0.10" +backtrace = "0.3" +bitflags = "1.0" +bytes = "0.4" +byteorder = "1.2" +derive_more = "0.14" +encoding = "0.2" +futures = "0.1" +hashbrown = "0.1.8" +h2 = "0.1.16" +http = "0.1.16" +httparse = "1.3" +indexmap = "1.0" +lazy_static = "1.0" +language-tags = "0.2" +log = "0.4" +mime = "0.3" +percent-encoding = "1.0" +rand = "0.6" +regex = "1.0" +serde = "1.0" +serde_json = "1.0" +sha1 = "0.6" +slab = "0.4" +serde_urlencoded = "0.5.3" +time = "0.1" +tokio-tcp = "0.1.3" +tokio-timer = "0.2" +tokio-current-thread = "0.1" +trust-dns-resolver = { version="0.11.0-alpha.2", default-features = false } + +# optional deps +cookie = { version="0.11", features=["percent-encode"], optional = true } +failure = { version = "0.1.5", optional = true } +openssl = { version="0.10", optional = true } + +[dev-dependencies] +actix-rt = "0.2.1" +actix-server = { version = "0.4.0", features=["ssl"] } +actix-connect = { version = "0.1.0", features=["ssl"] } +actix-http-test = { path="test-server", features=["ssl"] } + +env_logger = "0.6" +serde_derive = "1.0" +openssl = { version="0.10" } +tokio-tcp = "0.1" \ No newline at end of file diff --git a/actix-http/LICENSE-APACHE b/actix-http/LICENSE-APACHE new file mode 100644 index 000000000..6cdf2d16c --- /dev/null +++ b/actix-http/LICENSE-APACHE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2017-NOW Nikolay Kim + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/actix-http/LICENSE-MIT b/actix-http/LICENSE-MIT new file mode 100644 index 000000000..0f80296ae --- /dev/null +++ b/actix-http/LICENSE-MIT @@ -0,0 +1,25 @@ +Copyright (c) 2017 Nikolay Kim + +Permission is hereby granted, free of charge, to any +person obtaining a copy of this software and associated +documentation files (the "Software"), to deal in the +Software without restriction, including without +limitation the rights to use, copy, modify, merge, +publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software +is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice +shall be included in all copies or substantial portions +of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. diff --git a/actix-http/README.md b/actix-http/README.md new file mode 100644 index 000000000..467e67a9d --- /dev/null +++ b/actix-http/README.md @@ -0,0 +1,46 @@ +# Actix http [![Build Status](https://travis-ci.org/actix/actix-http.svg?branch=master)](https://travis-ci.org/actix/actix-http) [![codecov](https://codecov.io/gh/actix/actix-http/branch/master/graph/badge.svg)](https://codecov.io/gh/actix/actix-http) [![crates.io](https://meritbadge.herokuapp.com/actix-web)](https://crates.io/crates/actix-web) [![Join the chat at https://gitter.im/actix/actix](https://badges.gitter.im/actix/actix.svg)](https://gitter.im/actix/actix?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) + +Actix http + +## Documentation & community resources + +* [User Guide](https://actix.rs/docs/) +* [API Documentation](https://docs.rs/actix-http/) +* [Chat on gitter](https://gitter.im/actix/actix) +* Cargo package: [actix-http](https://crates.io/crates/actix-http) +* Minimum supported Rust version: 1.26 or later + +## Example + +```rust +// see examples/framed_hello.rs for complete list of used crates. +extern crate actix_http; +use actix_http::{h1, Response, ServiceConfig}; + +fn main() { + Server::new().bind("framed_hello", "127.0.0.1:8080", || { + IntoFramed::new(|| h1::Codec::new(ServiceConfig::default())) // <- create h1 codec + .and_then(TakeItem::new().map_err(|_| ())) // <- read one request + .and_then(|(_req, _framed): (_, Framed<_, _>)| { // <- send response and close conn + SendResponse::send(_framed, Response::Ok().body("Hello world!")) + .map_err(|_| ()) + .map(|_| ()) + }) + }).unwrap().run(); +} +``` + +## License + +This project is licensed under either of + +* Apache License, Version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or [http://www.apache.org/licenses/LICENSE-2.0](http://www.apache.org/licenses/LICENSE-2.0)) +* MIT license ([LICENSE-MIT](LICENSE-MIT) or [http://opensource.org/licenses/MIT](http://opensource.org/licenses/MIT)) + +at your option. + +## Code of Conduct + +Contribution to the actix-http crate is organized under the terms of the +Contributor Covenant, the maintainer of actix-http, @fafhrd91, promises to +intervene to uphold that code of conduct. diff --git a/actix-http/examples/client.rs b/actix-http/examples/client.rs new file mode 100644 index 000000000..7f5f8c91a --- /dev/null +++ b/actix-http/examples/client.rs @@ -0,0 +1,33 @@ +use actix_http::{client, Error}; +use actix_rt::System; +use bytes::BytesMut; +use futures::{future::lazy, Future, Stream}; + +fn main() -> Result<(), Error> { + std::env::set_var("RUST_LOG", "actix_http=trace"); + env_logger::init(); + + System::new("test").block_on(lazy(|| { + let mut connector = client::Connector::new().service(); + + client::ClientRequest::get("https://www.rust-lang.org/") // <- Create request builder + .header("User-Agent", "Actix-web") + .finish() + .unwrap() + .send(&mut connector) // <- Send http request + .from_err() + .and_then(|response| { + // <- server http response + println!("Response: {:?}", response); + + // read response body + response + .from_err() + .fold(BytesMut::new(), move |mut acc, chunk| { + acc.extend_from_slice(&chunk); + Ok::<_, Error>(acc) + }) + .map(|body| println!("Downloaded: {:?} bytes", body.len())) + }) + })) +} diff --git a/actix-http/examples/echo.rs b/actix-http/examples/echo.rs new file mode 100644 index 000000000..c36292c44 --- /dev/null +++ b/actix-http/examples/echo.rs @@ -0,0 +1,37 @@ +use std::{env, io}; + +use actix_http::{error::PayloadError, HttpService, Request, Response}; +use actix_server::Server; +use bytes::BytesMut; +use futures::{Future, Stream}; +use http::header::HeaderValue; +use log::info; + +fn main() -> io::Result<()> { + env::set_var("RUST_LOG", "echo=info"); + env_logger::init(); + + Server::build() + .bind("echo", "127.0.0.1:8080", || { + HttpService::build() + .client_timeout(1000) + .client_disconnect(1000) + .finish(|mut req: Request| { + req.take_payload() + .fold(BytesMut::new(), move |mut body, chunk| { + body.extend_from_slice(&chunk); + Ok::<_, PayloadError>(body) + }) + .and_then(|bytes| { + info!("request body: {:?}", bytes); + let mut res = Response::Ok(); + res.header( + "x-head", + HeaderValue::from_static("dummy value!"), + ); + Ok(res.body(bytes)) + }) + }) + })? + .run() +} diff --git a/actix-http/examples/echo2.rs b/actix-http/examples/echo2.rs new file mode 100644 index 000000000..b239796b4 --- /dev/null +++ b/actix-http/examples/echo2.rs @@ -0,0 +1,34 @@ +use std::{env, io}; + +use actix_http::http::HeaderValue; +use actix_http::{error::PayloadError, Error, HttpService, Request, Response}; +use actix_server::Server; +use bytes::BytesMut; +use futures::{Future, Stream}; +use log::info; + +fn handle_request(mut req: Request) -> impl Future { + req.take_payload() + .fold(BytesMut::new(), move |mut body, chunk| { + body.extend_from_slice(&chunk); + Ok::<_, PayloadError>(body) + }) + .from_err() + .and_then(|bytes| { + info!("request body: {:?}", bytes); + let mut res = Response::Ok(); + res.header("x-head", HeaderValue::from_static("dummy value!")); + Ok(res.body(bytes)) + }) +} + +fn main() -> io::Result<()> { + env::set_var("RUST_LOG", "echo=info"); + env_logger::init(); + + Server::build() + .bind("echo", "127.0.0.1:8080", || { + HttpService::build().finish(|_req: Request| handle_request(_req)) + })? + .run() +} diff --git a/actix-http/examples/framed_hello.rs b/actix-http/examples/framed_hello.rs new file mode 100644 index 000000000..7d4c13d34 --- /dev/null +++ b/actix-http/examples/framed_hello.rs @@ -0,0 +1,28 @@ +use std::{env, io}; + +use actix_codec::Framed; +use actix_http::{h1, Response, SendResponse, ServiceConfig}; +use actix_server::{Io, Server}; +use actix_service::{fn_service, NewService}; +use actix_utils::framed::IntoFramed; +use actix_utils::stream::TakeItem; +use futures::Future; +use tokio_tcp::TcpStream; + +fn main() -> io::Result<()> { + env::set_var("RUST_LOG", "framed_hello=info"); + env_logger::init(); + + Server::build() + .bind("framed_hello", "127.0.0.1:8080", || { + fn_service(|io: Io| Ok(io.into_parts().0)) + .and_then(IntoFramed::new(|| h1::Codec::new(ServiceConfig::default()))) + .and_then(TakeItem::new().map_err(|_| ())) + .and_then(|(_req, _framed): (_, Framed<_, _>)| { + SendResponse::send(_framed, Response::Ok().body("Hello world!")) + .map_err(|_| ()) + .map(|_| ()) + }) + })? + .run() +} diff --git a/actix-http/examples/hello-world.rs b/actix-http/examples/hello-world.rs new file mode 100644 index 000000000..6e3820390 --- /dev/null +++ b/actix-http/examples/hello-world.rs @@ -0,0 +1,26 @@ +use std::{env, io}; + +use actix_http::{HttpService, Response}; +use actix_server::Server; +use futures::future; +use http::header::HeaderValue; +use log::info; + +fn main() -> io::Result<()> { + env::set_var("RUST_LOG", "hello_world=info"); + env_logger::init(); + + Server::build() + .bind("hello-world", "127.0.0.1:8080", || { + HttpService::build() + .client_timeout(1000) + .client_disconnect(1000) + .finish(|_req| { + info!("{:?}", _req); + let mut res = Response::Ok(); + res.header("x-head", HeaderValue::from_static("dummy value!")); + future::ok::<_, ()>(res.body("Hello world!")) + }) + })? + .run() +} diff --git a/actix-http/rustfmt.toml b/actix-http/rustfmt.toml new file mode 100644 index 000000000..5fcaaca0f --- /dev/null +++ b/actix-http/rustfmt.toml @@ -0,0 +1,5 @@ +max_width = 89 +reorder_imports = true +#wrap_comments = true +#fn_args_density = "Compressed" +#use_small_heuristics = false diff --git a/actix-http/src/body.rs b/actix-http/src/body.rs new file mode 100644 index 000000000..e1399e6b4 --- /dev/null +++ b/actix-http/src/body.rs @@ -0,0 +1,465 @@ +use std::marker::PhantomData; +use std::{fmt, mem}; + +use bytes::{Bytes, BytesMut}; +use futures::{Async, Poll, Stream}; + +use crate::error::Error; + +#[derive(Debug, PartialEq, Copy, Clone)] +/// Different type of body +pub enum BodyLength { + None, + Empty, + Sized(usize), + Sized64(u64), + Stream, +} + +impl BodyLength { + pub fn is_eof(&self) -> bool { + match self { + BodyLength::None + | BodyLength::Empty + | BodyLength::Sized(0) + | BodyLength::Sized64(0) => true, + _ => false, + } + } +} + +/// Type that provides this trait can be streamed to a peer. +pub trait MessageBody { + fn length(&self) -> BodyLength; + + fn poll_next(&mut self) -> Poll, Error>; +} + +impl MessageBody for () { + fn length(&self) -> BodyLength { + BodyLength::Empty + } + + fn poll_next(&mut self) -> Poll, Error> { + Ok(Async::Ready(None)) + } +} + +impl MessageBody for Box { + fn length(&self) -> BodyLength { + self.as_ref().length() + } + + fn poll_next(&mut self) -> Poll, Error> { + self.as_mut().poll_next() + } +} + +pub enum ResponseBody { + Body(B), + Other(Body), +} + +impl ResponseBody { + pub fn into_body(self) -> ResponseBody { + match self { + ResponseBody::Body(b) => ResponseBody::Other(b), + ResponseBody::Other(b) => ResponseBody::Other(b), + } + } +} + +impl ResponseBody { + pub fn take_body(&mut self) -> ResponseBody { + std::mem::replace(self, ResponseBody::Other(Body::None)) + } +} + +impl ResponseBody { + pub fn as_ref(&self) -> Option<&B> { + if let ResponseBody::Body(ref b) = self { + Some(b) + } else { + None + } + } +} + +impl MessageBody for ResponseBody { + fn length(&self) -> BodyLength { + match self { + ResponseBody::Body(ref body) => body.length(), + ResponseBody::Other(ref body) => body.length(), + } + } + + fn poll_next(&mut self) -> Poll, Error> { + match self { + ResponseBody::Body(ref mut body) => body.poll_next(), + ResponseBody::Other(ref mut body) => body.poll_next(), + } + } +} + +impl Stream for ResponseBody { + type Item = Bytes; + type Error = Error; + + fn poll(&mut self) -> Poll, Self::Error> { + self.poll_next() + } +} + +/// Represents various types of http message body. +pub enum Body { + /// Empty response. `Content-Length` header is not set. + None, + /// Zero sized response body. `Content-Length` header is set to `0`. + Empty, + /// Specific response body. + Bytes(Bytes), + /// Generic message body. + Message(Box), +} + +impl Body { + /// Create body from slice (copy) + pub fn from_slice(s: &[u8]) -> Body { + Body::Bytes(Bytes::from(s)) + } + + /// Create body from generic message body. + pub fn from_message(body: B) -> Body { + Body::Message(Box::new(body)) + } +} + +impl MessageBody for Body { + fn length(&self) -> BodyLength { + match self { + Body::None => BodyLength::None, + Body::Empty => BodyLength::Empty, + Body::Bytes(ref bin) => BodyLength::Sized(bin.len()), + Body::Message(ref body) => body.length(), + } + } + + fn poll_next(&mut self) -> Poll, Error> { + match self { + Body::None => Ok(Async::Ready(None)), + Body::Empty => Ok(Async::Ready(None)), + Body::Bytes(ref mut bin) => { + let len = bin.len(); + if len == 0 { + Ok(Async::Ready(None)) + } else { + Ok(Async::Ready(Some(bin.split_to(len)))) + } + } + Body::Message(ref mut body) => body.poll_next(), + } + } +} + +impl PartialEq for Body { + fn eq(&self, other: &Body) -> bool { + match *self { + Body::None => match *other { + Body::None => true, + _ => false, + }, + Body::Empty => match *other { + Body::Empty => true, + _ => false, + }, + Body::Bytes(ref b) => match *other { + Body::Bytes(ref b2) => b == b2, + _ => false, + }, + Body::Message(_) => false, + } + } +} + +impl fmt::Debug for Body { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + Body::None => write!(f, "Body::None"), + Body::Empty => write!(f, "Body::Zero"), + Body::Bytes(ref b) => write!(f, "Body::Bytes({:?})", b), + Body::Message(_) => write!(f, "Body::Message(_)"), + } + } +} + +impl From<&'static str> for Body { + fn from(s: &'static str) -> Body { + Body::Bytes(Bytes::from_static(s.as_ref())) + } +} + +impl From<&'static [u8]> for Body { + fn from(s: &'static [u8]) -> Body { + Body::Bytes(Bytes::from_static(s)) + } +} + +impl From> for Body { + fn from(vec: Vec) -> Body { + Body::Bytes(Bytes::from(vec)) + } +} + +impl From for Body { + fn from(s: String) -> Body { + s.into_bytes().into() + } +} + +impl<'a> From<&'a String> for Body { + fn from(s: &'a String) -> Body { + Body::Bytes(Bytes::from(AsRef::<[u8]>::as_ref(&s))) + } +} + +impl From for Body { + fn from(s: Bytes) -> Body { + Body::Bytes(s) + } +} + +impl From for Body { + fn from(s: BytesMut) -> Body { + Body::Bytes(s.freeze()) + } +} + +impl MessageBody for Bytes { + fn length(&self) -> BodyLength { + BodyLength::Sized(self.len()) + } + + fn poll_next(&mut self) -> Poll, Error> { + if self.is_empty() { + Ok(Async::Ready(None)) + } else { + Ok(Async::Ready(Some(mem::replace(self, Bytes::new())))) + } + } +} + +impl MessageBody for BytesMut { + fn length(&self) -> BodyLength { + BodyLength::Sized(self.len()) + } + + fn poll_next(&mut self) -> Poll, Error> { + if self.is_empty() { + Ok(Async::Ready(None)) + } else { + Ok(Async::Ready(Some( + mem::replace(self, BytesMut::new()).freeze(), + ))) + } + } +} + +impl MessageBody for &'static str { + fn length(&self) -> BodyLength { + BodyLength::Sized(self.len()) + } + + fn poll_next(&mut self) -> Poll, Error> { + if self.is_empty() { + Ok(Async::Ready(None)) + } else { + Ok(Async::Ready(Some(Bytes::from_static( + mem::replace(self, "").as_ref(), + )))) + } + } +} + +impl MessageBody for &'static [u8] { + fn length(&self) -> BodyLength { + BodyLength::Sized(self.len()) + } + + fn poll_next(&mut self) -> Poll, Error> { + if self.is_empty() { + Ok(Async::Ready(None)) + } else { + Ok(Async::Ready(Some(Bytes::from_static(mem::replace( + self, b"", + ))))) + } + } +} + +impl MessageBody for Vec { + fn length(&self) -> BodyLength { + BodyLength::Sized(self.len()) + } + + fn poll_next(&mut self) -> Poll, Error> { + if self.is_empty() { + Ok(Async::Ready(None)) + } else { + Ok(Async::Ready(Some(Bytes::from(mem::replace( + self, + Vec::new(), + ))))) + } + } +} + +impl MessageBody for String { + fn length(&self) -> BodyLength { + BodyLength::Sized(self.len()) + } + + fn poll_next(&mut self) -> Poll, Error> { + if self.is_empty() { + Ok(Async::Ready(None)) + } else { + Ok(Async::Ready(Some(Bytes::from( + mem::replace(self, String::new()).into_bytes(), + )))) + } + } +} + +/// Type represent streaming body. +/// Response does not contain `content-length` header and appropriate transfer encoding is used. +pub struct BodyStream { + stream: S, + _t: PhantomData, +} + +impl BodyStream +where + S: Stream, + E: Into, +{ + pub fn new(stream: S) -> Self { + BodyStream { + stream, + _t: PhantomData, + } + } +} + +impl MessageBody for BodyStream +where + S: Stream, + E: Into, +{ + fn length(&self) -> BodyLength { + BodyLength::Stream + } + + fn poll_next(&mut self) -> Poll, Error> { + self.stream.poll().map_err(|e| e.into()) + } +} + +/// Type represent streaming body. This body implementation should be used +/// if total size of stream is known. Data get sent as is without using transfer encoding. +pub struct SizedStream { + size: usize, + stream: S, +} + +impl SizedStream +where + S: Stream, +{ + pub fn new(size: usize, stream: S) -> Self { + SizedStream { size, stream } + } +} + +impl MessageBody for SizedStream +where + S: Stream, +{ + fn length(&self) -> BodyLength { + BodyLength::Sized(self.size) + } + + fn poll_next(&mut self) -> Poll, Error> { + self.stream.poll() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + impl Body { + pub(crate) fn get_ref(&self) -> &[u8] { + match *self { + Body::Bytes(ref bin) => &bin, + _ => panic!(), + } + } + } + + impl ResponseBody { + pub(crate) fn get_ref(&self) -> &[u8] { + match *self { + ResponseBody::Body(ref b) => b.get_ref(), + ResponseBody::Other(ref b) => b.get_ref(), + } + } + } + + #[test] + fn test_static_str() { + assert_eq!(Body::from("").length(), BodyLength::Sized(0)); + assert_eq!(Body::from("test").length(), BodyLength::Sized(4)); + assert_eq!(Body::from("test").get_ref(), b"test"); + } + + #[test] + fn test_static_bytes() { + assert_eq!(Body::from(b"test".as_ref()).length(), BodyLength::Sized(4)); + assert_eq!(Body::from(b"test".as_ref()).get_ref(), b"test"); + assert_eq!( + Body::from_slice(b"test".as_ref()).length(), + BodyLength::Sized(4) + ); + assert_eq!(Body::from_slice(b"test".as_ref()).get_ref(), b"test"); + } + + #[test] + fn test_vec() { + assert_eq!(Body::from(Vec::from("test")).length(), BodyLength::Sized(4)); + assert_eq!(Body::from(Vec::from("test")).get_ref(), b"test"); + } + + #[test] + fn test_bytes() { + assert_eq!( + Body::from(Bytes::from("test")).length(), + BodyLength::Sized(4) + ); + assert_eq!(Body::from(Bytes::from("test")).get_ref(), b"test"); + } + + #[test] + fn test_string() { + let b = "test".to_owned(); + assert_eq!(Body::from(b.clone()).length(), BodyLength::Sized(4)); + assert_eq!(Body::from(b.clone()).get_ref(), b"test"); + assert_eq!(Body::from(&b).length(), BodyLength::Sized(4)); + assert_eq!(Body::from(&b).get_ref(), b"test"); + } + + #[test] + fn test_bytes_mut() { + let b = BytesMut::from("test"); + assert_eq!(Body::from(b.clone()).length(), BodyLength::Sized(4)); + assert_eq!(Body::from(b).get_ref(), b"test"); + } +} diff --git a/actix-http/src/builder.rs b/actix-http/src/builder.rs new file mode 100644 index 000000000..2f7466a90 --- /dev/null +++ b/actix-http/src/builder.rs @@ -0,0 +1,141 @@ +use std::fmt::Debug; +use std::marker::PhantomData; + +use actix_server_config::ServerConfig as SrvConfig; +use actix_service::{IntoNewService, NewService}; + +use crate::body::MessageBody; +use crate::config::{KeepAlive, ServiceConfig}; +use crate::request::Request; +use crate::response::Response; + +use crate::h1::H1Service; +use crate::h2::H2Service; +use crate::service::HttpService; + +/// A http service builder +/// +/// This type can be used to construct an instance of `http service` through a +/// builder-like pattern. +pub struct HttpServiceBuilder { + keep_alive: KeepAlive, + client_timeout: u64, + client_disconnect: u64, + _t: PhantomData<(T, S)>, +} + +impl HttpServiceBuilder +where + S: NewService, + S::Error: Debug + 'static, + S::Service: 'static, +{ + /// Create instance of `ServiceConfigBuilder` + pub fn new() -> HttpServiceBuilder { + HttpServiceBuilder { + keep_alive: KeepAlive::Timeout(5), + client_timeout: 5000, + client_disconnect: 0, + _t: PhantomData, + } + } + + /// Set server keep-alive setting. + /// + /// By default keep alive is set to a 5 seconds. + pub fn keep_alive>(mut self, val: U) -> Self { + self.keep_alive = val.into(); + self + } + + /// Set server client timeout in milliseconds for first request. + /// + /// Defines a timeout for reading client request header. If a client does not transmit + /// the entire set headers within this time, the request is terminated with + /// the 408 (Request Time-out) error. + /// + /// To disable timeout set value to 0. + /// + /// By default client timeout is set to 5000 milliseconds. + pub fn client_timeout(mut self, val: u64) -> Self { + self.client_timeout = val; + self + } + + /// Set server connection disconnect timeout in milliseconds. + /// + /// Defines a timeout for disconnect connection. If a disconnect procedure does not complete + /// within this time, the request get dropped. This timeout affects secure connections. + /// + /// To disable timeout set value to 0. + /// + /// By default disconnect timeout is set to 0. + pub fn client_disconnect(mut self, val: u64) -> Self { + self.client_disconnect = val; + self + } + + // #[cfg(feature = "ssl")] + // /// Configure alpn protocols for SslAcceptorBuilder. + // pub fn configure_openssl( + // builder: &mut openssl::ssl::SslAcceptorBuilder, + // ) -> io::Result<()> { + // let protos: &[u8] = b"\x02h2"; + // builder.set_alpn_select_callback(|_, protos| { + // const H2: &[u8] = b"\x02h2"; + // if protos.windows(3).any(|window| window == H2) { + // Ok(b"h2") + // } else { + // Err(openssl::ssl::AlpnError::NOACK) + // } + // }); + // builder.set_alpn_protos(&protos)?; + + // Ok(()) + // } + + /// Finish service configuration and create *http service* for HTTP/1 protocol. + pub fn h1(self, service: F) -> H1Service + where + B: MessageBody + 'static, + F: IntoNewService, + S::Response: Into>, + { + let cfg = ServiceConfig::new( + self.keep_alive, + self.client_timeout, + self.client_disconnect, + ); + H1Service::with_config(cfg, service.into_new_service()) + } + + /// Finish service configuration and create *http service* for HTTP/2 protocol. + pub fn h2(self, service: F) -> H2Service + where + B: MessageBody + 'static, + F: IntoNewService, + S::Response: Into>, + { + let cfg = ServiceConfig::new( + self.keep_alive, + self.client_timeout, + self.client_disconnect, + ); + H2Service::with_config(cfg, service.into_new_service()) + } + + /// Finish service configuration and create `HttpService` instance. + pub fn finish(self, service: F) -> HttpService + where + B: MessageBody + 'static, + F: IntoNewService, + S::Response: Into>, + { + let cfg = ServiceConfig::new( + self.keep_alive, + self.client_timeout, + self.client_disconnect, + ); + HttpService::with_config(cfg, service.into_new_service()) + } +} diff --git a/actix-http/src/client/connection.rs b/actix-http/src/client/connection.rs new file mode 100644 index 000000000..e8c1201aa --- /dev/null +++ b/actix-http/src/client/connection.rs @@ -0,0 +1,132 @@ +use std::{fmt, time}; + +use actix_codec::{AsyncRead, AsyncWrite}; +use bytes::Bytes; +use futures::Future; +use h2::client::SendRequest; + +use crate::body::MessageBody; +use crate::message::{RequestHead, ResponseHead}; +use crate::payload::Payload; + +use super::error::SendRequestError; +use super::pool::Acquired; +use super::{h1proto, h2proto}; + +pub(crate) enum ConnectionType { + H1(Io), + H2(SendRequest), +} + +pub trait Connection { + type Future: Future; + + /// Send request and body + fn send_request( + self, + head: RequestHead, + body: B, + ) -> Self::Future; +} + +pub(crate) trait ConnectionLifetime: AsyncRead + AsyncWrite + 'static { + /// Close connection + fn close(&mut self); + + /// Release connection to the connection pool + fn release(&mut self); +} + +#[doc(hidden)] +/// HTTP client connection +pub struct IoConnection { + io: Option>, + created: time::Instant, + pool: Option>, +} + +impl fmt::Debug for IoConnection +where + T: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self.io { + Some(ConnectionType::H1(ref io)) => write!(f, "H1Connection({:?})", io), + Some(ConnectionType::H2(_)) => write!(f, "H2Connection"), + None => write!(f, "Connection(Empty)"), + } + } +} + +impl IoConnection { + pub(crate) fn new( + io: ConnectionType, + created: time::Instant, + pool: Option>, + ) -> Self { + IoConnection { + pool, + created, + io: Some(io), + } + } + + pub(crate) fn into_inner(self) -> (ConnectionType, time::Instant) { + (self.io.unwrap(), self.created) + } +} + +impl Connection for IoConnection +where + T: AsyncRead + AsyncWrite + 'static, +{ + type Future = Box>; + + fn send_request( + mut self, + head: RequestHead, + body: B, + ) -> Self::Future { + match self.io.take().unwrap() { + ConnectionType::H1(io) => Box::new(h1proto::send_request( + io, + head, + body, + self.created, + self.pool, + )), + ConnectionType::H2(io) => Box::new(h2proto::send_request( + io, + head, + body, + self.created, + self.pool, + )), + } + } +} + +#[allow(dead_code)] +pub(crate) enum EitherConnection { + A(IoConnection), + B(IoConnection), +} + +impl Connection for EitherConnection +where + A: AsyncRead + AsyncWrite + 'static, + B: AsyncRead + AsyncWrite + 'static, +{ + type Future = Box>; + + fn send_request( + self, + head: RequestHead, + body: RB, + ) -> Self::Future { + match self { + EitherConnection::A(con) => con.send_request(head, body), + EitherConnection::B(con) => con.send_request(head, body), + } + } +} diff --git a/actix-http/src/client/connector.rs b/actix-http/src/client/connector.rs new file mode 100644 index 000000000..804756ced --- /dev/null +++ b/actix-http/src/client/connector.rs @@ -0,0 +1,442 @@ +use std::fmt; +use std::marker::PhantomData; +use std::time::Duration; + +use actix_codec::{AsyncRead, AsyncWrite}; +use actix_connect::{ + default_connector, Connect as TcpConnect, Connection as TcpConnection, +}; +use actix_service::{apply_fn, Service, ServiceExt}; +use actix_utils::timeout::{TimeoutError, TimeoutService}; +use http::Uri; +use tokio_tcp::TcpStream; + +use super::connection::Connection; +use super::error::ConnectError; +use super::pool::{ConnectionPool, Protocol}; + +#[cfg(feature = "ssl")] +use openssl::ssl::SslConnector; + +#[cfg(not(feature = "ssl"))] +type SslConnector = (); + +/// Http client connector builde instance. +/// `Connector` type uses builder-like pattern for connector service construction. +pub struct Connector { + connector: T, + timeout: Duration, + conn_lifetime: Duration, + conn_keep_alive: Duration, + disconnect_timeout: Duration, + limit: usize, + #[allow(dead_code)] + ssl: SslConnector, + _t: PhantomData, +} + +impl Connector<(), ()> { + pub fn new() -> Connector< + impl Service< + Request = TcpConnect, + Response = TcpConnection, + Error = actix_connect::ConnectError, + > + Clone, + TcpStream, + > { + let ssl = { + #[cfg(feature = "ssl")] + { + use log::error; + use openssl::ssl::{SslConnector, SslMethod}; + + let mut ssl = SslConnector::builder(SslMethod::tls()).unwrap(); + let _ = ssl + .set_alpn_protos(b"\x02h2\x08http/1.1") + .map_err(|e| error!("Can not set alpn protocol: {:?}", e)); + ssl.build() + } + #[cfg(not(feature = "ssl"))] + {} + }; + + Connector { + ssl, + connector: default_connector(), + timeout: Duration::from_secs(1), + conn_lifetime: Duration::from_secs(75), + conn_keep_alive: Duration::from_secs(15), + disconnect_timeout: Duration::from_millis(3000), + limit: 100, + _t: PhantomData, + } + } +} + +impl Connector { + /// Use custom connector. + pub fn connector(self, connector: T1) -> Connector + where + U1: AsyncRead + AsyncWrite + fmt::Debug, + T1: Service< + Request = TcpConnect, + Response = TcpConnection, + Error = actix_connect::ConnectError, + > + Clone, + { + Connector { + connector, + timeout: self.timeout, + conn_lifetime: self.conn_lifetime, + conn_keep_alive: self.conn_keep_alive, + disconnect_timeout: self.disconnect_timeout, + limit: self.limit, + ssl: self.ssl, + _t: PhantomData, + } + } +} + +impl Connector +where + U: AsyncRead + AsyncWrite + fmt::Debug + 'static, + T: Service< + Request = TcpConnect, + Response = TcpConnection, + Error = actix_connect::ConnectError, + > + Clone, +{ + /// Connection timeout, i.e. max time to connect to remote host including dns name resolution. + /// Set to 1 second by default. + pub fn timeout(mut self, timeout: Duration) -> Self { + self.timeout = timeout; + self + } + + #[cfg(feature = "ssl")] + /// Use custom `SslConnector` instance. + pub fn ssl(mut self, connector: SslConnector) -> Self { + self.ssl = connector; + self + } + + /// Set total number of simultaneous connections per type of scheme. + /// + /// If limit is 0, the connector has no limit. + /// The default limit size is 100. + pub fn limit(mut self, limit: usize) -> Self { + self.limit = limit; + self + } + + /// Set keep-alive period for opened connection. + /// + /// Keep-alive period is the period between connection usage. If + /// the delay between repeated usages of the same connection + /// exceeds this period, the connection is closed. + /// Default keep-alive period is 15 seconds. + pub fn conn_keep_alive(mut self, dur: Duration) -> Self { + self.conn_keep_alive = dur; + self + } + + /// Set max lifetime period for connection. + /// + /// Connection lifetime is max lifetime of any opened connection + /// until it is closed regardless of keep-alive period. + /// Default lifetime period is 75 seconds. + pub fn conn_lifetime(mut self, dur: Duration) -> Self { + self.conn_lifetime = dur; + self + } + + /// Set server connection disconnect timeout in milliseconds. + /// + /// Defines a timeout for disconnect connection. If a disconnect procedure does not complete + /// within this time, the socket get dropped. This timeout affects only secure connections. + /// + /// To disable timeout set value to 0. + /// + /// By default disconnect timeout is set to 3000 milliseconds. + pub fn disconnect_timeout(mut self, dur: Duration) -> Self { + self.disconnect_timeout = dur; + self + } + + /// Finish configuration process and create connector service. + pub fn service( + self, + ) -> impl Service + Clone + { + #[cfg(not(feature = "ssl"))] + { + let connector = TimeoutService::new( + self.timeout, + apply_fn(self.connector, |msg: Uri, srv| srv.call(msg.into())) + .map_err(ConnectError::from) + .map(|stream| (stream.into_parts().0, Protocol::Http1)), + ) + .map_err(|e| match e { + TimeoutError::Service(e) => e, + TimeoutError::Timeout => ConnectError::Timeout, + }); + + connect_impl::InnerConnector { + tcp_pool: ConnectionPool::new( + connector, + self.conn_lifetime, + self.conn_keep_alive, + None, + self.limit, + ), + } + } + #[cfg(feature = "ssl")] + { + const H2: &[u8] = b"h2"; + use actix_connect::ssl::OpensslConnector; + + let ssl_service = TimeoutService::new( + self.timeout, + apply_fn(self.connector.clone(), |msg: Uri, srv| srv.call(msg.into())) + .map_err(ConnectError::from) + .and_then( + OpensslConnector::service(self.ssl) + .map_err(ConnectError::from) + .map(|stream| { + let sock = stream.into_parts().0; + let h2 = sock + .get_ref() + .ssl() + .selected_alpn_protocol() + .map(|protos| protos.windows(2).any(|w| w == H2)) + .unwrap_or(false); + if h2 { + (sock, Protocol::Http2) + } else { + (sock, Protocol::Http1) + } + }), + ), + ) + .map_err(|e| match e { + TimeoutError::Service(e) => e, + TimeoutError::Timeout => ConnectError::Timeout, + }); + + let tcp_service = TimeoutService::new( + self.timeout, + apply_fn(self.connector.clone(), |msg: Uri, srv| srv.call(msg.into())) + .map_err(ConnectError::from) + .map(|stream| (stream.into_parts().0, Protocol::Http1)), + ) + .map_err(|e| match e { + TimeoutError::Service(e) => e, + TimeoutError::Timeout => ConnectError::Timeout, + }); + + connect_impl::InnerConnector { + tcp_pool: ConnectionPool::new( + tcp_service, + self.conn_lifetime, + self.conn_keep_alive, + None, + self.limit, + ), + ssl_pool: ConnectionPool::new( + ssl_service, + self.conn_lifetime, + self.conn_keep_alive, + Some(self.disconnect_timeout), + self.limit, + ), + } + } + } +} + +#[cfg(not(feature = "ssl"))] +mod connect_impl { + use futures::future::{err, Either, FutureResult}; + use futures::Poll; + + use super::*; + use crate::client::connection::IoConnection; + + pub(crate) struct InnerConnector + where + Io: AsyncRead + AsyncWrite + 'static, + T: Service, + { + pub(crate) tcp_pool: ConnectionPool, + } + + impl Clone for InnerConnector + where + Io: AsyncRead + AsyncWrite + 'static, + T: Service + + Clone, + { + fn clone(&self) -> Self { + InnerConnector { + tcp_pool: self.tcp_pool.clone(), + } + } + } + + impl Service for InnerConnector + where + Io: AsyncRead + AsyncWrite + 'static, + T: Service, + { + type Request = Uri; + type Response = IoConnection; + type Error = ConnectError; + type Future = Either< + as Service>::Future, + FutureResult, ConnectError>, + >; + + fn poll_ready(&mut self) -> Poll<(), Self::Error> { + self.tcp_pool.poll_ready() + } + + fn call(&mut self, req: Uri) -> Self::Future { + match req.scheme_str() { + Some("https") | Some("wss") => { + Either::B(err(ConnectError::SslIsNotSupported)) + } + _ => Either::A(self.tcp_pool.call(req)), + } + } + } +} + +#[cfg(feature = "ssl")] +mod connect_impl { + use std::marker::PhantomData; + + use futures::future::{Either, FutureResult}; + use futures::{Async, Future, Poll}; + + use super::*; + use crate::client::connection::EitherConnection; + + pub(crate) struct InnerConnector + where + Io1: AsyncRead + AsyncWrite + 'static, + Io2: AsyncRead + AsyncWrite + 'static, + T1: Service, + T2: Service, + { + pub(crate) tcp_pool: ConnectionPool, + pub(crate) ssl_pool: ConnectionPool, + } + + impl Clone for InnerConnector + where + Io1: AsyncRead + AsyncWrite + 'static, + Io2: AsyncRead + AsyncWrite + 'static, + T1: Service + + Clone, + T2: Service + + Clone, + { + fn clone(&self) -> Self { + InnerConnector { + tcp_pool: self.tcp_pool.clone(), + ssl_pool: self.ssl_pool.clone(), + } + } + } + + impl Service for InnerConnector + where + Io1: AsyncRead + AsyncWrite + 'static, + Io2: AsyncRead + AsyncWrite + 'static, + T1: Service, + T2: Service, + { + type Request = Uri; + type Response = EitherConnection; + type Error = ConnectError; + type Future = Either< + FutureResult, + Either< + InnerConnectorResponseA, + InnerConnectorResponseB, + >, + >; + + fn poll_ready(&mut self) -> Poll<(), Self::Error> { + self.tcp_pool.poll_ready() + } + + fn call(&mut self, req: Uri) -> Self::Future { + match req.scheme_str() { + Some("https") | Some("wss") => { + Either::B(Either::B(InnerConnectorResponseB { + fut: self.ssl_pool.call(req), + _t: PhantomData, + })) + } + _ => Either::B(Either::A(InnerConnectorResponseA { + fut: self.tcp_pool.call(req), + _t: PhantomData, + })), + } + } + } + + pub(crate) struct InnerConnectorResponseA + where + Io1: AsyncRead + AsyncWrite + 'static, + T: Service, + { + fut: as Service>::Future, + _t: PhantomData, + } + + impl Future for InnerConnectorResponseA + where + T: Service, + Io1: AsyncRead + AsyncWrite + 'static, + Io2: AsyncRead + AsyncWrite + 'static, + { + type Item = EitherConnection; + type Error = ConnectError; + + fn poll(&mut self) -> Poll { + match self.fut.poll()? { + Async::NotReady => Ok(Async::NotReady), + Async::Ready(res) => Ok(Async::Ready(EitherConnection::A(res))), + } + } + } + + pub(crate) struct InnerConnectorResponseB + where + Io2: AsyncRead + AsyncWrite + 'static, + T: Service, + { + fut: as Service>::Future, + _t: PhantomData, + } + + impl Future for InnerConnectorResponseB + where + T: Service, + Io1: AsyncRead + AsyncWrite + 'static, + Io2: AsyncRead + AsyncWrite + 'static, + { + type Item = EitherConnection; + type Error = ConnectError; + + fn poll(&mut self) -> Poll { + match self.fut.poll()? { + Async::NotReady => Ok(Async::NotReady), + Async::Ready(res) => Ok(Async::Ready(EitherConnection::B(res))), + } + } + } +} diff --git a/actix-http/src/client/error.rs b/actix-http/src/client/error.rs new file mode 100644 index 000000000..69ec49585 --- /dev/null +++ b/actix-http/src/client/error.rs @@ -0,0 +1,124 @@ +use std::io; + +use derive_more::{Display, From}; +use trust_dns_resolver::error::ResolveError; + +#[cfg(feature = "ssl")] +use openssl::ssl::{Error as SslError, HandshakeError}; + +use crate::error::{Error, ParseError, ResponseError}; +use crate::http::Error as HttpError; +use crate::response::Response; + +/// A set of errors that can occur while connecting to an HTTP host +#[derive(Debug, Display, From)] +pub enum ConnectError { + /// SSL feature is not enabled + #[display(fmt = "SSL is not supported")] + SslIsNotSupported, + + /// SSL error + #[cfg(feature = "ssl")] + #[display(fmt = "{}", _0)] + SslError(SslError), + + /// Failed to resolve the hostname + #[display(fmt = "Failed resolving hostname: {}", _0)] + Resolver(ResolveError), + + /// No dns records + #[display(fmt = "No dns records found for the input")] + NoRecords, + + /// Http2 error + #[display(fmt = "{}", _0)] + H2(h2::Error), + + /// Connecting took too long + #[display(fmt = "Timeout out while establishing connection")] + Timeout, + + /// Connector has been disconnected + #[display(fmt = "Internal error: connector has been disconnected")] + Disconnected, + + /// Unresolved host name + #[display(fmt = "Connector received `Connect` method with unresolved host")] + Unresolverd, + + /// Connection io error + #[display(fmt = "{}", _0)] + Io(io::Error), +} + +impl From for ConnectError { + fn from(err: actix_connect::ConnectError) -> ConnectError { + match err { + actix_connect::ConnectError::Resolver(e) => ConnectError::Resolver(e), + actix_connect::ConnectError::NoRecords => ConnectError::NoRecords, + actix_connect::ConnectError::InvalidInput => panic!(), + actix_connect::ConnectError::Unresolverd => ConnectError::Unresolverd, + actix_connect::ConnectError::Io(e) => ConnectError::Io(e), + } + } +} + +#[cfg(feature = "ssl")] +impl From> for ConnectError { + fn from(err: HandshakeError) -> ConnectError { + match err { + HandshakeError::SetupFailure(stack) => SslError::from(stack).into(), + HandshakeError::Failure(stream) => stream.into_error().into(), + HandshakeError::WouldBlock(stream) => stream.into_error().into(), + } + } +} + +#[derive(Debug, Display, From)] +pub enum InvalidUrl { + #[display(fmt = "Missing url scheme")] + MissingScheme, + #[display(fmt = "Unknown url scheme")] + UnknownScheme, + #[display(fmt = "Missing host name")] + MissingHost, + #[display(fmt = "Url parse error: {}", _0)] + HttpError(http::Error), +} + +/// A set of errors that can occur during request sending and response reading +#[derive(Debug, Display, From)] +pub enum SendRequestError { + /// Invalid URL + #[display(fmt = "Invalid URL: {}", _0)] + Url(InvalidUrl), + /// Failed to connect to host + #[display(fmt = "Failed to connect to host: {}", _0)] + Connect(ConnectError), + /// Error sending request + Send(io::Error), + /// Error parsing response + Response(ParseError), + /// Http error + #[display(fmt = "{}", _0)] + Http(HttpError), + /// Http2 error + #[display(fmt = "{}", _0)] + H2(h2::Error), + /// Error sending request body + Body(Error), +} + +/// Convert `SendRequestError` to a server `Response` +impl ResponseError for SendRequestError { + fn error_response(&self) -> Response { + match *self { + SendRequestError::Connect(ConnectError::Timeout) => { + Response::GatewayTimeout() + } + SendRequestError::Connect(_) => Response::BadGateway(), + _ => Response::InternalServerError(), + } + .into() + } +} diff --git a/actix-http/src/client/h1proto.rs b/actix-http/src/client/h1proto.rs new file mode 100644 index 000000000..2e29484ff --- /dev/null +++ b/actix-http/src/client/h1proto.rs @@ -0,0 +1,248 @@ +use std::{io, time}; + +use actix_codec::{AsyncRead, AsyncWrite, Framed}; +use bytes::Bytes; +use futures::future::{ok, Either}; +use futures::{Async, Future, Poll, Sink, Stream}; + +use crate::error::PayloadError; +use crate::h1; +use crate::message::{RequestHead, ResponseHead}; +use crate::payload::{Payload, PayloadStream}; + +use super::connection::{ConnectionLifetime, ConnectionType, IoConnection}; +use super::error::{ConnectError, SendRequestError}; +use super::pool::Acquired; +use crate::body::{BodyLength, MessageBody}; + +pub(crate) fn send_request( + io: T, + head: RequestHead, + body: B, + created: time::Instant, + pool: Option>, +) -> impl Future +where + T: AsyncRead + AsyncWrite + 'static, + B: MessageBody, +{ + let io = H1Connection { + created, + pool, + io: Some(io), + }; + + let len = body.length(); + + // create Framed and send reqest + Framed::new(io, h1::ClientCodec::default()) + .send((head, len).into()) + .from_err() + // send request body + .and_then(move |framed| match body.length() { + BodyLength::None | BodyLength::Empty | BodyLength::Sized(0) => { + Either::A(ok(framed)) + } + _ => Either::B(SendBody::new(body, framed)), + }) + // read response and init read body + .and_then(|framed| { + framed + .into_future() + .map_err(|(e, _)| SendRequestError::from(e)) + .and_then(|(item, framed)| { + if let Some(res) = item { + match framed.get_codec().message_type() { + h1::MessageType::None => { + let force_close = !framed.get_codec().keepalive(); + release_connection(framed, force_close); + Ok((res, Payload::None)) + } + _ => { + let pl: PayloadStream = Box::new(PlStream::new(framed)); + Ok((res, pl.into())) + } + } + } else { + Err(ConnectError::Disconnected.into()) + } + }) + }) +} + +#[doc(hidden)] +/// HTTP client connection +pub struct H1Connection { + io: Option, + created: time::Instant, + pool: Option>, +} + +impl ConnectionLifetime for H1Connection { + /// Close connection + fn close(&mut self) { + if let Some(mut pool) = self.pool.take() { + if let Some(io) = self.io.take() { + pool.close(IoConnection::new( + ConnectionType::H1(io), + self.created, + None, + )); + } + } + } + + /// Release this connection to the connection pool + fn release(&mut self) { + if let Some(mut pool) = self.pool.take() { + if let Some(io) = self.io.take() { + pool.release(IoConnection::new( + ConnectionType::H1(io), + self.created, + None, + )); + } + } + } +} + +impl io::Read for H1Connection { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + self.io.as_mut().unwrap().read(buf) + } +} + +impl AsyncRead for H1Connection {} + +impl io::Write for H1Connection { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.io.as_mut().unwrap().write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + self.io.as_mut().unwrap().flush() + } +} + +impl AsyncWrite for H1Connection { + fn shutdown(&mut self) -> Poll<(), io::Error> { + self.io.as_mut().unwrap().shutdown() + } +} + +/// Future responsible for sending request body to the peer +pub(crate) struct SendBody { + body: Option, + framed: Option>, + flushed: bool, +} + +impl SendBody +where + I: AsyncRead + AsyncWrite + 'static, + B: MessageBody, +{ + pub(crate) fn new(body: B, framed: Framed) -> Self { + SendBody { + body: Some(body), + framed: Some(framed), + flushed: true, + } + } +} + +impl Future for SendBody +where + I: ConnectionLifetime, + B: MessageBody, +{ + type Item = Framed; + type Error = SendRequestError; + + fn poll(&mut self) -> Poll { + let mut body_ready = true; + loop { + while body_ready + && self.body.is_some() + && !self.framed.as_ref().unwrap().is_write_buf_full() + { + match self.body.as_mut().unwrap().poll_next()? { + Async::Ready(item) => { + // check if body is done + if item.is_none() { + let _ = self.body.take(); + } + self.flushed = false; + self.framed + .as_mut() + .unwrap() + .force_send(h1::Message::Chunk(item))?; + break; + } + Async::NotReady => body_ready = false, + } + } + + if !self.flushed { + match self.framed.as_mut().unwrap().poll_complete()? { + Async::Ready(_) => { + self.flushed = true; + continue; + } + Async::NotReady => return Ok(Async::NotReady), + } + } + + if self.body.is_none() { + return Ok(Async::Ready(self.framed.take().unwrap())); + } + return Ok(Async::NotReady); + } + } +} + +pub(crate) struct PlStream { + framed: Option>, +} + +impl PlStream { + fn new(framed: Framed) -> Self { + PlStream { + framed: Some(framed.map_codec(|codec| codec.into_payload_codec())), + } + } +} + +impl Stream for PlStream { + type Item = Bytes; + type Error = PayloadError; + + fn poll(&mut self) -> Poll, Self::Error> { + match self.framed.as_mut().unwrap().poll()? { + Async::NotReady => Ok(Async::NotReady), + Async::Ready(Some(chunk)) => { + if let Some(chunk) = chunk { + Ok(Async::Ready(Some(chunk))) + } else { + let framed = self.framed.take().unwrap(); + let force_close = framed.get_codec().keepalive(); + release_connection(framed, force_close); + Ok(Async::Ready(None)) + } + } + Async::Ready(None) => Ok(Async::Ready(None)), + } + } +} + +fn release_connection(framed: Framed, force_close: bool) +where + T: ConnectionLifetime, +{ + let mut parts = framed.into_parts(); + if !force_close && parts.read_buf.is_empty() && parts.write_buf.is_empty() { + parts.io.release() + } else { + parts.io.close() + } +} diff --git a/actix-http/src/client/h2proto.rs b/actix-http/src/client/h2proto.rs new file mode 100644 index 000000000..9ad722627 --- /dev/null +++ b/actix-http/src/client/h2proto.rs @@ -0,0 +1,185 @@ +use std::time; + +use actix_codec::{AsyncRead, AsyncWrite}; +use bytes::Bytes; +use futures::future::{err, Either}; +use futures::{Async, Future, Poll}; +use h2::{client::SendRequest, SendStream}; +use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, TRANSFER_ENCODING}; +use http::{request::Request, HttpTryFrom, Method, Version}; + +use crate::body::{BodyLength, MessageBody}; +use crate::message::{RequestHead, ResponseHead}; +use crate::payload::Payload; + +use super::connection::{ConnectionType, IoConnection}; +use super::error::SendRequestError; +use super::pool::Acquired; + +pub(crate) fn send_request( + io: SendRequest, + head: RequestHead, + body: B, + created: time::Instant, + pool: Option>, +) -> impl Future +where + T: AsyncRead + AsyncWrite + 'static, + B: MessageBody, +{ + trace!("Sending client request: {:?} {:?}", head, body.length()); + let head_req = head.method == Method::HEAD; + let length = body.length(); + let eof = match length { + BodyLength::None | BodyLength::Empty | BodyLength::Sized(0) => true, + _ => false, + }; + + io.ready() + .map_err(SendRequestError::from) + .and_then(move |mut io| { + let mut req = Request::new(()); + *req.uri_mut() = head.uri; + *req.method_mut() = head.method; + *req.version_mut() = Version::HTTP_2; + + let mut skip_len = true; + // let mut has_date = false; + + // Content length + let _ = match length { + BodyLength::None => None, + BodyLength::Stream => { + skip_len = false; + None + } + BodyLength::Empty => req + .headers_mut() + .insert(CONTENT_LENGTH, HeaderValue::from_static("0")), + BodyLength::Sized(len) => req.headers_mut().insert( + CONTENT_LENGTH, + HeaderValue::try_from(format!("{}", len)).unwrap(), + ), + BodyLength::Sized64(len) => req.headers_mut().insert( + CONTENT_LENGTH, + HeaderValue::try_from(format!("{}", len)).unwrap(), + ), + }; + + // copy headers + for (key, value) in head.headers.iter() { + match *key { + CONNECTION | TRANSFER_ENCODING => continue, // http2 specific + CONTENT_LENGTH if skip_len => continue, + // DATE => has_date = true, + _ => (), + } + req.headers_mut().append(key, value.clone()); + } + + match io.send_request(req, eof) { + Ok((res, send)) => { + release(io, pool, created, false); + + if !eof { + Either::A(Either::B( + SendBody { + body, + send, + buf: None, + } + .and_then(move |_| res.map_err(SendRequestError::from)), + )) + } else { + Either::B(res.map_err(SendRequestError::from)) + } + } + Err(e) => { + release(io, pool, created, e.is_io()); + Either::A(Either::A(err(e.into()))) + } + } + }) + .and_then(move |resp| { + let (parts, body) = resp.into_parts(); + let payload = if head_req { Payload::None } else { body.into() }; + + let mut head = ResponseHead::default(); + head.version = parts.version; + head.status = parts.status; + head.headers = parts.headers; + + Ok((head, payload)) + }) + .from_err() +} + +struct SendBody { + body: B, + send: SendStream, + buf: Option, +} + +impl Future for SendBody { + type Item = (); + type Error = SendRequestError; + + fn poll(&mut self) -> Poll { + loop { + if self.buf.is_none() { + match self.body.poll_next() { + Ok(Async::Ready(Some(buf))) => { + self.send.reserve_capacity(buf.len()); + self.buf = Some(buf); + } + Ok(Async::Ready(None)) => { + if let Err(e) = self.send.send_data(Bytes::new(), true) { + return Err(e.into()); + } + self.send.reserve_capacity(0); + return Ok(Async::Ready(())); + } + Ok(Async::NotReady) => return Ok(Async::NotReady), + Err(e) => return Err(e.into()), + } + } + + match self.send.poll_capacity() { + Ok(Async::NotReady) => return Ok(Async::NotReady), + Ok(Async::Ready(None)) => return Ok(Async::Ready(())), + Ok(Async::Ready(Some(cap))) => { + let mut buf = self.buf.take().unwrap(); + let len = buf.len(); + let bytes = buf.split_to(std::cmp::min(cap, len)); + + if let Err(e) = self.send.send_data(bytes, false) { + return Err(e.into()); + } else { + if !buf.is_empty() { + self.send.reserve_capacity(buf.len()); + self.buf = Some(buf); + } + continue; + } + } + Err(e) => return Err(e.into()), + } + } + } +} + +// release SendRequest object +fn release( + io: SendRequest, + pool: Option>, + created: time::Instant, + close: bool, +) { + if let Some(mut pool) = pool { + if close { + pool.close(IoConnection::new(ConnectionType::H2(io), created, None)); + } else { + pool.release(IoConnection::new(ConnectionType::H2(io), created, None)); + } + } +} diff --git a/actix-http/src/client/mod.rs b/actix-http/src/client/mod.rs new file mode 100644 index 000000000..87c374742 --- /dev/null +++ b/actix-http/src/client/mod.rs @@ -0,0 +1,11 @@ +//! Http client api +mod connection; +mod connector; +mod error; +mod h1proto; +mod h2proto; +mod pool; + +pub use self::connection::Connection; +pub use self::connector::Connector; +pub use self::error::{ConnectError, InvalidUrl, SendRequestError}; diff --git a/actix-http/src/client/pool.rs b/actix-http/src/client/pool.rs new file mode 100644 index 000000000..a94b1e52a --- /dev/null +++ b/actix-http/src/client/pool.rs @@ -0,0 +1,538 @@ +use std::cell::RefCell; +use std::collections::VecDeque; +use std::io; +use std::rc::Rc; +use std::time::{Duration, Instant}; + +use actix_codec::{AsyncRead, AsyncWrite}; +use actix_service::Service; +use bytes::Bytes; +use futures::future::{err, ok, Either, FutureResult}; +use futures::task::AtomicTask; +use futures::unsync::oneshot; +use futures::{Async, Future, Poll}; +use h2::client::{handshake, Handshake}; +use hashbrown::HashMap; +use http::uri::{Authority, Uri}; +use indexmap::IndexSet; +use slab::Slab; +use tokio_timer::{sleep, Delay}; + +use super::connection::{ConnectionType, IoConnection}; +use super::error::ConnectError; + +#[derive(Clone, Copy, PartialEq)] +pub enum Protocol { + Http1, + Http2, +} + +#[derive(Hash, Eq, PartialEq, Clone, Debug)] +pub(crate) struct Key { + authority: Authority, +} + +impl From for Key { + fn from(authority: Authority) -> Key { + Key { authority } + } +} + +/// Connections pool +pub(crate) struct ConnectionPool( + T, + Rc>>, +); + +impl ConnectionPool +where + Io: AsyncRead + AsyncWrite + 'static, + T: Service, +{ + pub(crate) fn new( + connector: T, + conn_lifetime: Duration, + conn_keep_alive: Duration, + disconnect_timeout: Option, + limit: usize, + ) -> Self { + ConnectionPool( + connector, + Rc::new(RefCell::new(Inner { + conn_lifetime, + conn_keep_alive, + disconnect_timeout, + limit, + acquired: 0, + waiters: Slab::new(), + waiters_queue: IndexSet::new(), + available: HashMap::new(), + task: AtomicTask::new(), + })), + ) + } +} + +impl Clone for ConnectionPool +where + T: Clone, + Io: AsyncRead + AsyncWrite + 'static, +{ + fn clone(&self) -> Self { + ConnectionPool(self.0.clone(), self.1.clone()) + } +} + +impl Service for ConnectionPool +where + Io: AsyncRead + AsyncWrite + 'static, + T: Service, +{ + type Request = Uri; + type Response = IoConnection; + type Error = ConnectError; + type Future = Either< + FutureResult, + Either, OpenConnection>, + >; + + fn poll_ready(&mut self) -> Poll<(), Self::Error> { + self.0.poll_ready() + } + + fn call(&mut self, req: Uri) -> Self::Future { + let key = if let Some(authority) = req.authority_part() { + authority.clone().into() + } else { + return Either::A(err(ConnectError::Unresolverd)); + }; + + // acquire connection + match self.1.as_ref().borrow_mut().acquire(&key) { + Acquire::Acquired(io, created) => { + // use existing connection + Either::A(ok(IoConnection::new( + io, + created, + Some(Acquired(key, Some(self.1.clone()))), + ))) + } + Acquire::NotAvailable => { + // connection is not available, wait + let (rx, token) = self.1.as_ref().borrow_mut().wait_for(req); + Either::B(Either::A(WaitForConnection { + rx, + key, + token, + inner: Some(self.1.clone()), + })) + } + Acquire::Available => { + // open new connection + Either::B(Either::B(OpenConnection::new( + key, + self.1.clone(), + self.0.call(req), + ))) + } + } + } +} + +#[doc(hidden)] +pub struct WaitForConnection +where + Io: AsyncRead + AsyncWrite + 'static, +{ + key: Key, + token: usize, + rx: oneshot::Receiver, ConnectError>>, + inner: Option>>>, +} + +impl Drop for WaitForConnection +where + Io: AsyncRead + AsyncWrite + 'static, +{ + fn drop(&mut self) { + if let Some(i) = self.inner.take() { + let mut inner = i.as_ref().borrow_mut(); + inner.release_waiter(&self.key, self.token); + inner.check_availibility(); + } + } +} + +impl Future for WaitForConnection +where + Io: AsyncRead + AsyncWrite, +{ + type Item = IoConnection; + type Error = ConnectError; + + fn poll(&mut self) -> Poll { + match self.rx.poll() { + Ok(Async::Ready(item)) => match item { + Err(err) => Err(err), + Ok(conn) => { + let _ = self.inner.take(); + Ok(Async::Ready(conn)) + } + }, + Ok(Async::NotReady) => Ok(Async::NotReady), + Err(_) => { + let _ = self.inner.take(); + Err(ConnectError::Disconnected) + } + } + } +} + +#[doc(hidden)] +pub struct OpenConnection +where + Io: AsyncRead + AsyncWrite + 'static, +{ + fut: F, + key: Key, + h2: Option>, + inner: Option>>>, +} + +impl OpenConnection +where + F: Future, + Io: AsyncRead + AsyncWrite + 'static, +{ + fn new(key: Key, inner: Rc>>, fut: F) -> Self { + OpenConnection { + key, + fut, + inner: Some(inner), + h2: None, + } + } +} + +impl Drop for OpenConnection +where + Io: AsyncRead + AsyncWrite + 'static, +{ + fn drop(&mut self) { + if let Some(inner) = self.inner.take() { + let mut inner = inner.as_ref().borrow_mut(); + inner.release(); + inner.check_availibility(); + } + } +} + +impl Future for OpenConnection +where + F: Future, + Io: AsyncRead + AsyncWrite, +{ + type Item = IoConnection; + type Error = ConnectError; + + fn poll(&mut self) -> Poll { + if let Some(ref mut h2) = self.h2 { + return match h2.poll() { + Ok(Async::Ready((snd, connection))) => { + tokio_current_thread::spawn(connection.map_err(|_| ())); + Ok(Async::Ready(IoConnection::new( + ConnectionType::H2(snd), + Instant::now(), + Some(Acquired(self.key.clone(), self.inner.clone())), + ))) + } + Ok(Async::NotReady) => Ok(Async::NotReady), + Err(e) => Err(e.into()), + }; + } + + match self.fut.poll() { + Err(err) => Err(err), + Ok(Async::Ready((io, proto))) => { + let _ = self.inner.take(); + if proto == Protocol::Http1 { + Ok(Async::Ready(IoConnection::new( + ConnectionType::H1(io), + Instant::now(), + Some(Acquired(self.key.clone(), self.inner.clone())), + ))) + } else { + self.h2 = Some(handshake(io)); + self.poll() + } + } + Ok(Async::NotReady) => Ok(Async::NotReady), + } + } +} + +enum Acquire { + Acquired(ConnectionType, Instant), + Available, + NotAvailable, +} + +// #[derive(Debug)] +struct AvailableConnection { + io: ConnectionType, + used: Instant, + created: Instant, +} + +pub(crate) struct Inner { + conn_lifetime: Duration, + conn_keep_alive: Duration, + disconnect_timeout: Option, + limit: usize, + acquired: usize, + available: HashMap>>, + waiters: Slab<(Uri, oneshot::Sender, ConnectError>>)>, + waiters_queue: IndexSet<(Key, usize)>, + task: AtomicTask, +} + +impl Inner { + fn reserve(&mut self) { + self.acquired += 1; + } + + fn release(&mut self) { + self.acquired -= 1; + } + + fn release_waiter(&mut self, key: &Key, token: usize) { + self.waiters.remove(token); + self.waiters_queue.remove(&(key.clone(), token)); + } + + fn release_conn(&mut self, key: &Key, io: ConnectionType, created: Instant) { + self.acquired -= 1; + self.available + .entry(key.clone()) + .or_insert_with(VecDeque::new) + .push_back(AvailableConnection { + io, + created, + used: Instant::now(), + }); + } +} + +impl Inner +where + Io: AsyncRead + AsyncWrite + 'static, +{ + /// connection is not available, wait + fn wait_for( + &mut self, + connect: Uri, + ) -> ( + oneshot::Receiver, ConnectError>>, + usize, + ) { + let (tx, rx) = oneshot::channel(); + + let key: Key = connect.authority_part().unwrap().clone().into(); + let entry = self.waiters.vacant_entry(); + let token = entry.key(); + entry.insert((connect, tx)); + assert!(!self.waiters_queue.insert((key, token))); + (rx, token) + } + + fn acquire(&mut self, key: &Key) -> Acquire { + // check limits + if self.limit > 0 && self.acquired >= self.limit { + return Acquire::NotAvailable; + } + + self.reserve(); + + // check if open connection is available + // cleanup stale connections at the same time + if let Some(ref mut connections) = self.available.get_mut(key) { + let now = Instant::now(); + while let Some(conn) = connections.pop_back() { + // check if it still usable + if (now - conn.used) > self.conn_keep_alive + || (now - conn.created) > self.conn_lifetime + { + if let Some(timeout) = self.disconnect_timeout { + if let ConnectionType::H1(io) = conn.io { + tokio_current_thread::spawn(CloseConnection::new( + io, timeout, + )) + } + } + } else { + let mut io = conn.io; + let mut buf = [0; 2]; + if let ConnectionType::H1(ref mut s) = io { + match s.read(&mut buf) { + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => (), + Ok(n) if n > 0 => { + if let Some(timeout) = self.disconnect_timeout { + if let ConnectionType::H1(io) = io { + tokio_current_thread::spawn( + CloseConnection::new(io, timeout), + ) + } + } + continue; + } + Ok(_) | Err(_) => continue, + } + } + return Acquire::Acquired(io, conn.created); + } + } + } + Acquire::Available + } + + fn release_close(&mut self, io: ConnectionType) { + self.acquired -= 1; + if let Some(timeout) = self.disconnect_timeout { + if let ConnectionType::H1(io) = io { + tokio_current_thread::spawn(CloseConnection::new(io, timeout)) + } + } + } + + fn check_availibility(&self) { + if !self.waiters_queue.is_empty() && self.acquired < self.limit { + self.task.notify() + } + } +} + +// struct ConnectorPoolSupport +// where +// Io: AsyncRead + AsyncWrite + 'static, +// { +// connector: T, +// inner: Rc>>, +// } + +// impl Future for ConnectorPoolSupport +// where +// Io: AsyncRead + AsyncWrite + 'static, +// T: Service, +// T::Future: 'static, +// { +// type Item = (); +// type Error = (); + +// fn poll(&mut self) -> Poll { +// let mut inner = self.inner.as_ref().borrow_mut(); +// inner.task.register(); + +// // check waiters +// loop { +// let (key, token) = { +// if let Some((key, token)) = inner.waiters_queue.get_index(0) { +// (key.clone(), *token) +// } else { +// break; +// } +// }; +// match inner.acquire(&key) { +// Acquire::NotAvailable => break, +// Acquire::Acquired(io, created) => { +// let (_, tx) = inner.waiters.remove(token); +// if let Err(conn) = tx.send(Ok(IoConnection::new( +// io, +// created, +// Some(Acquired(key.clone(), Some(self.inner.clone()))), +// ))) { +// let (io, created) = conn.unwrap().into_inner(); +// inner.release_conn(&key, io, created); +// } +// } +// Acquire::Available => { +// let (connect, tx) = inner.waiters.remove(token); +// OpenWaitingConnection::spawn( +// key.clone(), +// tx, +// self.inner.clone(), +// self.connector.call(connect), +// ); +// } +// } +// let _ = inner.waiters_queue.swap_remove_index(0); +// } + +// Ok(Async::NotReady) +// } +// } + +struct CloseConnection { + io: T, + timeout: Delay, +} + +impl CloseConnection +where + T: AsyncWrite, +{ + fn new(io: T, timeout: Duration) -> Self { + CloseConnection { + io, + timeout: sleep(timeout), + } + } +} + +impl Future for CloseConnection +where + T: AsyncWrite, +{ + type Item = (); + type Error = (); + + fn poll(&mut self) -> Poll<(), ()> { + match self.timeout.poll() { + Ok(Async::Ready(_)) | Err(_) => Ok(Async::Ready(())), + Ok(Async::NotReady) => match self.io.shutdown() { + Ok(Async::Ready(_)) | Err(_) => Ok(Async::Ready(())), + Ok(Async::NotReady) => Ok(Async::NotReady), + }, + } + } +} + +pub(crate) struct Acquired(Key, Option>>>); + +impl Acquired +where + T: AsyncRead + AsyncWrite + 'static, +{ + pub(crate) fn close(&mut self, conn: IoConnection) { + if let Some(inner) = self.1.take() { + let (io, _) = conn.into_inner(); + inner.as_ref().borrow_mut().release_close(io); + } + } + pub(crate) fn release(&mut self, conn: IoConnection) { + if let Some(inner) = self.1.take() { + let (io, created) = conn.into_inner(); + inner + .as_ref() + .borrow_mut() + .release_conn(&self.0, io, created); + } + } +} + +impl Drop for Acquired { + fn drop(&mut self) { + if let Some(inner) = self.1.take() { + inner.as_ref().borrow_mut().release(); + } + } +} diff --git a/actix-http/src/config.rs b/actix-http/src/config.rs new file mode 100644 index 000000000..f7d7f5f94 --- /dev/null +++ b/actix-http/src/config.rs @@ -0,0 +1,290 @@ +use std::cell::UnsafeCell; +use std::fmt; +use std::fmt::Write; +use std::rc::Rc; +use std::time::{Duration, Instant}; + +use bytes::BytesMut; +use futures::{future, Future}; +use time; +use tokio_timer::{sleep, Delay}; + +// "Sun, 06 Nov 1994 08:49:37 GMT".len() +const DATE_VALUE_LENGTH: usize = 29; + +#[derive(Debug, PartialEq, Clone, Copy)] +/// Server keep-alive setting +pub enum KeepAlive { + /// Keep alive in seconds + Timeout(usize), + /// Relay on OS to shutdown tcp connection + Os, + /// Disabled + Disabled, +} + +impl From for KeepAlive { + fn from(keepalive: usize) -> Self { + KeepAlive::Timeout(keepalive) + } +} + +impl From> for KeepAlive { + fn from(keepalive: Option) -> Self { + if let Some(keepalive) = keepalive { + KeepAlive::Timeout(keepalive) + } else { + KeepAlive::Disabled + } + } +} + +/// Http service configuration +pub struct ServiceConfig(Rc); + +struct Inner { + keep_alive: Option, + client_timeout: u64, + client_disconnect: u64, + ka_enabled: bool, + timer: DateService, +} + +impl Clone for ServiceConfig { + fn clone(&self) -> Self { + ServiceConfig(self.0.clone()) + } +} + +impl Default for ServiceConfig { + fn default() -> Self { + Self::new(KeepAlive::Timeout(5), 0, 0) + } +} + +impl ServiceConfig { + /// Create instance of `ServiceConfig` + pub fn new( + keep_alive: KeepAlive, + client_timeout: u64, + client_disconnect: u64, + ) -> ServiceConfig { + let (keep_alive, ka_enabled) = match keep_alive { + KeepAlive::Timeout(val) => (val as u64, true), + KeepAlive::Os => (0, true), + KeepAlive::Disabled => (0, false), + }; + let keep_alive = if ka_enabled && keep_alive > 0 { + Some(Duration::from_secs(keep_alive)) + } else { + None + }; + + ServiceConfig(Rc::new(Inner { + keep_alive, + ka_enabled, + client_timeout, + client_disconnect, + timer: DateService::new(), + })) + } + + #[inline] + /// Keep alive duration if configured. + pub fn keep_alive(&self) -> Option { + self.0.keep_alive + } + + #[inline] + /// Return state of connection keep-alive funcitonality + pub fn keep_alive_enabled(&self) -> bool { + self.0.ka_enabled + } + + #[inline] + /// Client timeout for first request. + pub fn client_timer(&self) -> Option { + let delay = self.0.client_timeout; + if delay != 0 { + Some(Delay::new( + self.0.timer.now() + Duration::from_millis(delay), + )) + } else { + None + } + } + + /// Client timeout for first request. + pub fn client_timer_expire(&self) -> Option { + let delay = self.0.client_timeout; + if delay != 0 { + Some(self.0.timer.now() + Duration::from_millis(delay)) + } else { + None + } + } + + /// Client disconnect timer + pub fn client_disconnect_timer(&self) -> Option { + let delay = self.0.client_disconnect; + if delay != 0 { + Some(self.0.timer.now() + Duration::from_millis(delay)) + } else { + None + } + } + + #[inline] + /// Return keep-alive timer delay is configured. + pub fn keep_alive_timer(&self) -> Option { + if let Some(ka) = self.0.keep_alive { + Some(Delay::new(self.0.timer.now() + ka)) + } else { + None + } + } + + /// Keep-alive expire time + pub fn keep_alive_expire(&self) -> Option { + if let Some(ka) = self.0.keep_alive { + Some(self.0.timer.now() + ka) + } else { + None + } + } + + #[inline] + pub(crate) fn now(&self) -> Instant { + self.0.timer.now() + } + + pub(crate) fn set_date(&self, dst: &mut BytesMut) { + let mut buf: [u8; 39] = [0; 39]; + buf[..6].copy_from_slice(b"date: "); + buf[6..35].copy_from_slice(&self.0.timer.date().bytes); + buf[35..].copy_from_slice(b"\r\n\r\n"); + dst.extend_from_slice(&buf); + } + + pub(crate) fn set_date_header(&self, dst: &mut BytesMut) { + dst.extend_from_slice(&self.0.timer.date().bytes); + } +} + +struct Date { + bytes: [u8; DATE_VALUE_LENGTH], + pos: usize, +} + +impl Date { + fn new() -> Date { + let mut date = Date { + bytes: [0; DATE_VALUE_LENGTH], + pos: 0, + }; + date.update(); + date + } + fn update(&mut self) { + self.pos = 0; + write!(self, "{}", time::at_utc(time::get_time()).rfc822()).unwrap(); + } +} + +impl fmt::Write for Date { + fn write_str(&mut self, s: &str) -> fmt::Result { + let len = s.len(); + self.bytes[self.pos..self.pos + len].copy_from_slice(s.as_bytes()); + self.pos += len; + Ok(()) + } +} + +#[derive(Clone)] +struct DateService(Rc); + +struct DateServiceInner { + current: UnsafeCell>, +} + +impl DateServiceInner { + fn new() -> Self { + DateServiceInner { + current: UnsafeCell::new(None), + } + } + + fn get_ref(&self) -> &Option<(Date, Instant)> { + unsafe { &*self.current.get() } + } + + fn reset(&self) { + unsafe { (&mut *self.current.get()).take() }; + } + + fn update(&self) { + let now = Instant::now(); + let date = Date::new(); + *(unsafe { &mut *self.current.get() }) = Some((date, now)); + } +} + +impl DateService { + fn new() -> Self { + DateService(Rc::new(DateServiceInner::new())) + } + + fn check_date(&self) { + if self.0.get_ref().is_none() { + self.0.update(); + + // periodic date update + let s = self.clone(); + tokio_current_thread::spawn(sleep(Duration::from_millis(500)).then( + move |_| { + s.0.reset(); + future::ok(()) + }, + )); + } + } + + fn now(&self) -> Instant { + self.check_date(); + self.0.get_ref().as_ref().unwrap().1 + } + + fn date(&self) -> &Date { + self.check_date(); + + let item = self.0.get_ref().as_ref().unwrap(); + &item.0 + } +} + +#[cfg(test)] +mod tests { + use super::*; + use actix_rt::System; + use futures::future; + + #[test] + fn test_date_len() { + assert_eq!(DATE_VALUE_LENGTH, "Sun, 06 Nov 1994 08:49:37 GMT".len()); + } + + #[test] + fn test_date() { + let mut rt = System::new("test"); + + let _ = rt.block_on(future::lazy(|| { + let settings = ServiceConfig::new(KeepAlive::Os, 0, 0); + let mut buf1 = BytesMut::with_capacity(DATE_VALUE_LENGTH + 10); + settings.set_date(&mut buf1); + let mut buf2 = BytesMut::with_capacity(DATE_VALUE_LENGTH + 10); + settings.set_date(&mut buf2); + assert_eq!(buf1, buf2); + future::ok::<_, ()>(()) + })); + } +} diff --git a/actix-http/src/error.rs b/actix-http/src/error.rs new file mode 100644 index 000000000..820071b1e --- /dev/null +++ b/actix-http/src/error.rs @@ -0,0 +1,1118 @@ +//! Error and Result module +use std::cell::RefCell; +use std::io::Error as IoError; +use std::str::Utf8Error; +use std::string::FromUtf8Error; +use std::{fmt, io, result}; + +// use actix::MailboxError; +use actix_utils::timeout::TimeoutError; +use backtrace::Backtrace; +#[cfg(feature = "cookies")] +use cookie; +use derive_more::{Display, From}; +use futures::Canceled; +use http::uri::InvalidUri; +use http::{header, Error as HttpError, StatusCode}; +use httparse; +use serde::de::value::Error as DeError; +use serde_json::error::Error as JsonError; +use serde_urlencoded::ser::Error as FormError; +use tokio_timer::Error as TimerError; + +// re-export for convinience +#[cfg(feature = "cookies")] +pub use cookie::ParseError as CookieParseError; + +use crate::body::Body; +use crate::response::Response; + +/// A specialized [`Result`](https://doc.rust-lang.org/std/result/enum.Result.html) +/// for actix web operations +/// +/// This typedef is generally used to avoid writing out +/// `actix_http::error::Error` directly and is otherwise a direct mapping to +/// `Result`. +pub type Result = result::Result; + +/// General purpose actix web error. +/// +/// An actix web error is used to carry errors from `failure` or `std::error` +/// through actix in a convenient way. It can be created through +/// converting errors with `into()`. +/// +/// Whenever it is created from an external object a response error is created +/// for it that can be used to create an http response from it this means that +/// if you have access to an actix `Error` you can always get a +/// `ResponseError` reference from it. +pub struct Error { + cause: Box, + backtrace: Option, +} + +impl Error { + /// Returns the reference to the underlying `ResponseError`. + pub fn as_response_error(&self) -> &ResponseError { + self.cause.as_ref() + } + + /// Returns a reference to the Backtrace carried by this error, if it + /// carries one. + /// + /// This uses the same `Backtrace` type that `failure` uses. + pub fn backtrace(&self) -> &Backtrace { + if let Some(bt) = self.cause.backtrace() { + bt + } else { + self.backtrace.as_ref().unwrap() + } + } + + /// Converts error to a response instance and set error message as response body + pub fn response_with_message(self) -> Response { + let message = format!("{}", self); + let resp: Response = self.into(); + resp.set_body(Body::from(message)) + } +} + +/// Error that can be converted to `Response` +pub trait ResponseError: fmt::Debug + fmt::Display { + /// Create response for error + /// + /// Internal server error is generated by default. + fn error_response(&self) -> Response { + Response::new(StatusCode::INTERNAL_SERVER_ERROR) + } + + /// Response + fn backtrace(&self) -> Option<&Backtrace> { + None + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fmt::Display::fmt(&self.cause, f) + } +} + +impl fmt::Debug for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + if let Some(bt) = self.cause.backtrace() { + write!(f, "{:?}\n\n{:?}", &self.cause, bt) + } else { + write!( + f, + "{:?}\n\n{:?}", + &self.cause, + self.backtrace.as_ref().unwrap() + ) + } + } +} + +/// Convert `Error` to a `Response` instance +impl From for Response { + fn from(err: Error) -> Self { + Response::from_error(err) + } +} + +/// `Error` for any error that implements `ResponseError` +impl From for Error { + fn from(err: T) -> Error { + let backtrace = if err.backtrace().is_none() { + Some(Backtrace::new()) + } else { + None + }; + Error { + cause: Box::new(err), + backtrace, + } + } +} + +/// Return `GATEWAY_TIMEOUT` for `TimeoutError` +impl ResponseError for TimeoutError { + fn error_response(&self) -> Response { + match self { + TimeoutError::Service(e) => e.error_response(), + TimeoutError::Timeout => Response::new(StatusCode::GATEWAY_TIMEOUT), + } + } +} + +/// `InternalServerError` for `JsonError` +impl ResponseError for JsonError {} + +/// `InternalServerError` for `FormError` +impl ResponseError for FormError {} + +/// `InternalServerError` for `TimerError` +impl ResponseError for TimerError {} + +/// Return `BAD_REQUEST` for `de::value::Error` +impl ResponseError for DeError { + fn error_response(&self) -> Response { + Response::new(StatusCode::BAD_REQUEST) + } +} + +/// Return `BAD_REQUEST` for `Utf8Error` +impl ResponseError for Utf8Error { + fn error_response(&self) -> Response { + Response::new(StatusCode::BAD_REQUEST) + } +} + +/// Return `InternalServerError` for `HttpError`, +/// Response generation can return `HttpError`, so it is internal error +impl ResponseError for HttpError {} + +/// Return `InternalServerError` for `io::Error` +impl ResponseError for io::Error { + fn error_response(&self) -> Response { + match self.kind() { + io::ErrorKind::NotFound => Response::new(StatusCode::NOT_FOUND), + io::ErrorKind::PermissionDenied => Response::new(StatusCode::FORBIDDEN), + _ => Response::new(StatusCode::INTERNAL_SERVER_ERROR), + } + } +} + +/// `BadRequest` for `InvalidHeaderValue` +impl ResponseError for header::InvalidHeaderValue { + fn error_response(&self) -> Response { + Response::new(StatusCode::BAD_REQUEST) + } +} + +/// `BadRequest` for `InvalidHeaderValue` +impl ResponseError for header::InvalidHeaderValueBytes { + fn error_response(&self) -> Response { + Response::new(StatusCode::BAD_REQUEST) + } +} + +/// `InternalServerError` for `futures::Canceled` +impl ResponseError for Canceled {} + +// /// `InternalServerError` for `actix::MailboxError` +// impl ResponseError for MailboxError {} + +/// A set of errors that can occur during parsing HTTP streams +#[derive(Debug, Display)] +pub enum ParseError { + /// An invalid `Method`, such as `GE.T`. + #[display(fmt = "Invalid Method specified")] + Method, + /// An invalid `Uri`, such as `exam ple.domain`. + #[display(fmt = "Uri error: {}", _0)] + Uri(InvalidUri), + /// An invalid `HttpVersion`, such as `HTP/1.1` + #[display(fmt = "Invalid HTTP version specified")] + Version, + /// An invalid `Header`. + #[display(fmt = "Invalid Header provided")] + Header, + /// A message head is too large to be reasonable. + #[display(fmt = "Message head is too large")] + TooLarge, + /// A message reached EOF, but is not complete. + #[display(fmt = "Message is incomplete")] + Incomplete, + /// An invalid `Status`, such as `1337 ELITE`. + #[display(fmt = "Invalid Status provided")] + Status, + /// A timeout occurred waiting for an IO event. + #[allow(dead_code)] + #[display(fmt = "Timeout")] + Timeout, + /// An `io::Error` that occurred while trying to read or write to a network + /// stream. + #[display(fmt = "IO error: {}", _0)] + Io(IoError), + /// Parsing a field as string failed + #[display(fmt = "UTF8 error: {}", _0)] + Utf8(Utf8Error), +} + +/// Return `BadRequest` for `ParseError` +impl ResponseError for ParseError { + fn error_response(&self) -> Response { + Response::new(StatusCode::BAD_REQUEST) + } +} + +impl From for ParseError { + fn from(err: IoError) -> ParseError { + ParseError::Io(err) + } +} + +impl From for ParseError { + fn from(err: InvalidUri) -> ParseError { + ParseError::Uri(err) + } +} + +impl From for ParseError { + fn from(err: Utf8Error) -> ParseError { + ParseError::Utf8(err) + } +} + +impl From for ParseError { + fn from(err: FromUtf8Error) -> ParseError { + ParseError::Utf8(err.utf8_error()) + } +} + +impl From for ParseError { + fn from(err: httparse::Error) -> ParseError { + match err { + httparse::Error::HeaderName + | httparse::Error::HeaderValue + | httparse::Error::NewLine + | httparse::Error::Token => ParseError::Header, + httparse::Error::Status => ParseError::Status, + httparse::Error::TooManyHeaders => ParseError::TooLarge, + httparse::Error::Version => ParseError::Version, + } + } +} + +#[derive(Display, Debug, From)] +/// A set of errors that can occur during payload parsing +pub enum PayloadError { + /// A payload reached EOF, but is not complete. + #[display(fmt = "A payload reached EOF, but is not complete.")] + Incomplete(Option), + /// Content encoding stream corruption + #[display(fmt = "Can not decode content-encoding.")] + EncodingCorrupted, + /// A payload reached size limit. + #[display(fmt = "A payload reached size limit.")] + Overflow, + /// A payload length is unknown. + #[display(fmt = "A payload length is unknown.")] + UnknownLength, + /// Http2 payload error + #[display(fmt = "{}", _0)] + Http2Payload(h2::Error), +} + +impl From for PayloadError { + fn from(err: io::Error) -> Self { + PayloadError::Incomplete(Some(err)) + } +} + +/// `PayloadError` returns two possible results: +/// +/// - `Overflow` returns `PayloadTooLarge` +/// - Other errors returns `BadRequest` +impl ResponseError for PayloadError { + fn error_response(&self) -> Response { + match *self { + PayloadError::Overflow => Response::new(StatusCode::PAYLOAD_TOO_LARGE), + _ => Response::new(StatusCode::BAD_REQUEST), + } + } +} + +/// Return `BadRequest` for `cookie::ParseError` +#[cfg(feature = "cookies")] +impl ResponseError for cookie::ParseError { + fn error_response(&self) -> Response { + Response::new(StatusCode::BAD_REQUEST) + } +} + +#[derive(Debug, Display, From)] +/// A set of errors that can occur during dispatching http requests +pub enum DispatchError { + /// Service error + Service, + + /// An `io::Error` that occurred while trying to read or write to a network + /// stream. + #[display(fmt = "IO error: {}", _0)] + Io(io::Error), + + /// Http request parse error. + #[display(fmt = "Parse error: {}", _0)] + Parse(ParseError), + + /// Http/2 error + #[display(fmt = "{}", _0)] + H2(h2::Error), + + /// The first request did not complete within the specified timeout. + #[display(fmt = "The first request did not complete within the specified timeout")] + SlowRequestTimeout, + + /// Disconnect timeout. Makes sense for ssl streams. + #[display(fmt = "Connection shutdown timeout")] + DisconnectTimeout, + + /// Payload is not consumed + #[display(fmt = "Task is completed but request's payload is not consumed")] + PayloadIsNotConsumed, + + /// Malformed request + #[display(fmt = "Malformed request")] + MalformedRequest, + + /// Internal error + #[display(fmt = "Internal error")] + InternalError, + + /// Unknown error + #[display(fmt = "Unknown error")] + Unknown, +} + +/// A set of error that can occure during parsing content type +#[derive(PartialEq, Debug, Display)] +pub enum ContentTypeError { + /// Can not parse content type + #[display(fmt = "Can not parse content type")] + ParseError, + /// Unknown content encoding + #[display(fmt = "Unknown content encoding")] + UnknownEncoding, +} + +/// Return `BadRequest` for `ContentTypeError` +impl ResponseError for ContentTypeError { + fn error_response(&self) -> Response { + Response::new(StatusCode::BAD_REQUEST) + } +} + +/// Helper type that can wrap any error and generate custom response. +/// +/// In following example any `io::Error` will be converted into "BAD REQUEST" +/// response as opposite to *INTERNAL SERVER ERROR* which is defined by +/// default. +/// +/// ```rust +/// # extern crate actix_http; +/// # use std::io; +/// # use actix_http::*; +/// +/// fn index(req: Request) -> Result<&'static str> { +/// Err(error::ErrorBadRequest(io::Error::new(io::ErrorKind::Other, "error"))) +/// } +/// # fn main() {} +/// ``` +pub struct InternalError { + cause: T, + status: InternalErrorType, + backtrace: Backtrace, +} + +enum InternalErrorType { + Status(StatusCode), + Response(RefCell>), +} + +impl InternalError { + /// Create `InternalError` instance + pub fn new(cause: T, status: StatusCode) -> Self { + InternalError { + cause, + status: InternalErrorType::Status(status), + backtrace: Backtrace::new(), + } + } + + /// Create `InternalError` with predefined `Response`. + pub fn from_response(cause: T, response: Response) -> Self { + InternalError { + cause, + status: InternalErrorType::Response(RefCell::new(Some(response))), + backtrace: Backtrace::new(), + } + } +} + +impl fmt::Debug for InternalError +where + T: fmt::Debug + 'static, +{ + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fmt::Debug::fmt(&self.cause, f) + } +} + +impl fmt::Display for InternalError +where + T: fmt::Display + 'static, +{ + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fmt::Display::fmt(&self.cause, f) + } +} + +impl ResponseError for InternalError +where + T: fmt::Debug + fmt::Display + 'static, +{ + fn backtrace(&self) -> Option<&Backtrace> { + Some(&self.backtrace) + } + + fn error_response(&self) -> Response { + match self.status { + InternalErrorType::Status(st) => Response::new(st), + InternalErrorType::Response(ref resp) => { + if let Some(resp) = resp.borrow_mut().take() { + resp + } else { + Response::new(StatusCode::INTERNAL_SERVER_ERROR) + } + } + } + } +} + +/// Convert Response to a Error +impl From for Error { + fn from(res: Response) -> Error { + InternalError::from_response("", res).into() + } +} + +/// Helper function that creates wrapper of any error and generate *BAD +/// REQUEST* response. +#[allow(non_snake_case)] +pub fn ErrorBadRequest(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::BAD_REQUEST).into() +} + +/// Helper function that creates wrapper of any error and generate +/// *UNAUTHORIZED* response. +#[allow(non_snake_case)] +pub fn ErrorUnauthorized(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::UNAUTHORIZED).into() +} + +/// Helper function that creates wrapper of any error and generate +/// *PAYMENT_REQUIRED* response. +#[allow(non_snake_case)] +pub fn ErrorPaymentRequired(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::PAYMENT_REQUIRED).into() +} + +/// Helper function that creates wrapper of any error and generate *FORBIDDEN* +/// response. +#[allow(non_snake_case)] +pub fn ErrorForbidden(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::FORBIDDEN).into() +} + +/// Helper function that creates wrapper of any error and generate *NOT FOUND* +/// response. +#[allow(non_snake_case)] +pub fn ErrorNotFound(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::NOT_FOUND).into() +} + +/// Helper function that creates wrapper of any error and generate *METHOD NOT +/// ALLOWED* response. +#[allow(non_snake_case)] +pub fn ErrorMethodNotAllowed(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::METHOD_NOT_ALLOWED).into() +} + +/// Helper function that creates wrapper of any error and generate *NOT +/// ACCEPTABLE* response. +#[allow(non_snake_case)] +pub fn ErrorNotAcceptable(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::NOT_ACCEPTABLE).into() +} + +/// Helper function that creates wrapper of any error and generate *PROXY +/// AUTHENTICATION REQUIRED* response. +#[allow(non_snake_case)] +pub fn ErrorProxyAuthenticationRequired(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::PROXY_AUTHENTICATION_REQUIRED).into() +} + +/// Helper function that creates wrapper of any error and generate *REQUEST +/// TIMEOUT* response. +#[allow(non_snake_case)] +pub fn ErrorRequestTimeout(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::REQUEST_TIMEOUT).into() +} + +/// Helper function that creates wrapper of any error and generate *CONFLICT* +/// response. +#[allow(non_snake_case)] +pub fn ErrorConflict(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::CONFLICT).into() +} + +/// Helper function that creates wrapper of any error and generate *GONE* +/// response. +#[allow(non_snake_case)] +pub fn ErrorGone(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::GONE).into() +} + +/// Helper function that creates wrapper of any error and generate *LENGTH +/// REQUIRED* response. +#[allow(non_snake_case)] +pub fn ErrorLengthRequired(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::LENGTH_REQUIRED).into() +} + +/// Helper function that creates wrapper of any error and generate +/// *PAYLOAD TOO LARGE* response. +#[allow(non_snake_case)] +pub fn ErrorPayloadTooLarge(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::PAYLOAD_TOO_LARGE).into() +} + +/// Helper function that creates wrapper of any error and generate +/// *URI TOO LONG* response. +#[allow(non_snake_case)] +pub fn ErrorUriTooLong(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::URI_TOO_LONG).into() +} + +/// Helper function that creates wrapper of any error and generate +/// *UNSUPPORTED MEDIA TYPE* response. +#[allow(non_snake_case)] +pub fn ErrorUnsupportedMediaType(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::UNSUPPORTED_MEDIA_TYPE).into() +} + +/// Helper function that creates wrapper of any error and generate +/// *RANGE NOT SATISFIABLE* response. +#[allow(non_snake_case)] +pub fn ErrorRangeNotSatisfiable(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::RANGE_NOT_SATISFIABLE).into() +} + +/// Helper function that creates wrapper of any error and generate +/// *IM A TEAPOT* response. +#[allow(non_snake_case)] +pub fn ErrorImATeapot(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::IM_A_TEAPOT).into() +} + +/// Helper function that creates wrapper of any error and generate +/// *MISDIRECTED REQUEST* response. +#[allow(non_snake_case)] +pub fn ErrorMisdirectedRequest(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::MISDIRECTED_REQUEST).into() +} + +/// Helper function that creates wrapper of any error and generate +/// *UNPROCESSABLE ENTITY* response. +#[allow(non_snake_case)] +pub fn ErrorUnprocessableEntity(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::UNPROCESSABLE_ENTITY).into() +} + +/// Helper function that creates wrapper of any error and generate +/// *LOCKED* response. +#[allow(non_snake_case)] +pub fn ErrorLocked(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::LOCKED).into() +} + +/// Helper function that creates wrapper of any error and generate +/// *FAILED DEPENDENCY* response. +#[allow(non_snake_case)] +pub fn ErrorFailedDependency(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::FAILED_DEPENDENCY).into() +} + +/// Helper function that creates wrapper of any error and generate +/// *UPGRADE REQUIRED* response. +#[allow(non_snake_case)] +pub fn ErrorUpgradeRequired(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::UPGRADE_REQUIRED).into() +} + +/// Helper function that creates wrapper of any error and generate +/// *PRECONDITION FAILED* response. +#[allow(non_snake_case)] +pub fn ErrorPreconditionFailed(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::PRECONDITION_FAILED).into() +} + +/// Helper function that creates wrapper of any error and generate +/// *PRECONDITION REQUIRED* response. +#[allow(non_snake_case)] +pub fn ErrorPreconditionRequired(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::PRECONDITION_REQUIRED).into() +} + +/// Helper function that creates wrapper of any error and generate +/// *TOO MANY REQUESTS* response. +#[allow(non_snake_case)] +pub fn ErrorTooManyRequests(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::TOO_MANY_REQUESTS).into() +} + +/// Helper function that creates wrapper of any error and generate +/// *REQUEST HEADER FIELDS TOO LARGE* response. +#[allow(non_snake_case)] +pub fn ErrorRequestHeaderFieldsTooLarge(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE).into() +} + +/// Helper function that creates wrapper of any error and generate +/// *UNAVAILABLE FOR LEGAL REASONS* response. +#[allow(non_snake_case)] +pub fn ErrorUnavailableForLegalReasons(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS).into() +} + +/// Helper function that creates wrapper of any error and generate +/// *EXPECTATION FAILED* response. +#[allow(non_snake_case)] +pub fn ErrorExpectationFailed(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::EXPECTATION_FAILED).into() +} + +/// Helper function that creates wrapper of any error and +/// generate *INTERNAL SERVER ERROR* response. +#[allow(non_snake_case)] +pub fn ErrorInternalServerError(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::INTERNAL_SERVER_ERROR).into() +} + +/// Helper function that creates wrapper of any error and +/// generate *NOT IMPLEMENTED* response. +#[allow(non_snake_case)] +pub fn ErrorNotImplemented(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::NOT_IMPLEMENTED).into() +} + +/// Helper function that creates wrapper of any error and +/// generate *BAD GATEWAY* response. +#[allow(non_snake_case)] +pub fn ErrorBadGateway(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::BAD_GATEWAY).into() +} + +/// Helper function that creates wrapper of any error and +/// generate *SERVICE UNAVAILABLE* response. +#[allow(non_snake_case)] +pub fn ErrorServiceUnavailable(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::SERVICE_UNAVAILABLE).into() +} + +/// Helper function that creates wrapper of any error and +/// generate *GATEWAY TIMEOUT* response. +#[allow(non_snake_case)] +pub fn ErrorGatewayTimeout(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::GATEWAY_TIMEOUT).into() +} + +/// Helper function that creates wrapper of any error and +/// generate *HTTP VERSION NOT SUPPORTED* response. +#[allow(non_snake_case)] +pub fn ErrorHttpVersionNotSupported(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::HTTP_VERSION_NOT_SUPPORTED).into() +} + +/// Helper function that creates wrapper of any error and +/// generate *VARIANT ALSO NEGOTIATES* response. +#[allow(non_snake_case)] +pub fn ErrorVariantAlsoNegotiates(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::VARIANT_ALSO_NEGOTIATES).into() +} + +/// Helper function that creates wrapper of any error and +/// generate *INSUFFICIENT STORAGE* response. +#[allow(non_snake_case)] +pub fn ErrorInsufficientStorage(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::INSUFFICIENT_STORAGE).into() +} + +/// Helper function that creates wrapper of any error and +/// generate *LOOP DETECTED* response. +#[allow(non_snake_case)] +pub fn ErrorLoopDetected(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::LOOP_DETECTED).into() +} + +/// Helper function that creates wrapper of any error and +/// generate *NOT EXTENDED* response. +#[allow(non_snake_case)] +pub fn ErrorNotExtended(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::NOT_EXTENDED).into() +} + +/// Helper function that creates wrapper of any error and +/// generate *NETWORK AUTHENTICATION REQUIRED* response. +#[allow(non_snake_case)] +pub fn ErrorNetworkAuthenticationRequired(err: T) -> Error +where + T: fmt::Debug + fmt::Display + 'static, +{ + InternalError::new(err, StatusCode::NETWORK_AUTHENTICATION_REQUIRED).into() +} + +#[cfg(feature = "fail")] +mod failure_integration { + use super::*; + + /// Compatibility for `failure::Error` + impl ResponseError for failure::Error { + fn error_response(&self) -> Response { + Response::new(StatusCode::INTERNAL_SERVER_ERROR) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use http::{Error as HttpError, StatusCode}; + use httparse; + use std::error::Error as StdError; + use std::io; + + #[test] + fn test_into_response() { + let resp: Response = ParseError::Incomplete.error_response(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + + let err: HttpError = StatusCode::from_u16(10000).err().unwrap().into(); + let resp: Response = err.error_response(); + assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); + } + + #[cfg(feature = "cookies")] + #[test] + fn test_cookie_parse() { + use cookie::ParseError as CookieParseError; + let resp: Response = CookieParseError::EmptyName.error_response(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + } + + #[test] + fn test_as_response() { + let orig = io::Error::new(io::ErrorKind::Other, "other"); + let e: Error = ParseError::Io(orig).into(); + assert_eq!(format!("{}", e.as_response_error()), "IO error: other"); + } + + #[test] + fn test_backtrace() { + let e = ErrorBadRequest("err"); + let _ = e.backtrace(); + } + + #[test] + fn test_error_cause() { + let orig = io::Error::new(io::ErrorKind::Other, "other"); + let desc = orig.description().to_owned(); + let e = Error::from(orig); + assert_eq!(format!("{}", e.as_response_error()), desc); + } + + #[test] + fn test_error_display() { + let orig = io::Error::new(io::ErrorKind::Other, "other"); + let desc = orig.description().to_owned(); + let e = Error::from(orig); + assert_eq!(format!("{}", e), desc); + } + + #[test] + fn test_error_http_response() { + let orig = io::Error::new(io::ErrorKind::Other, "other"); + let e = Error::from(orig); + let resp: Response = e.into(); + assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); + } + + macro_rules! from { + ($from:expr => $error:pat) => { + match ParseError::from($from) { + e @ $error => { + assert!(format!("{}", e).len() >= 5); + } + e => unreachable!("{:?}", e), + } + }; + } + + macro_rules! from_and_cause { + ($from:expr => $error:pat) => { + match ParseError::from($from) { + e @ $error => { + let desc = format!("{}", e); + assert_eq!(desc, format!("IO error: {}", $from.description())); + } + _ => unreachable!("{:?}", $from), + } + }; + } + + #[test] + fn test_from() { + from_and_cause!(io::Error::new(io::ErrorKind::Other, "other") => ParseError::Io(..)); + from!(httparse::Error::HeaderName => ParseError::Header); + from!(httparse::Error::HeaderName => ParseError::Header); + from!(httparse::Error::HeaderValue => ParseError::Header); + from!(httparse::Error::NewLine => ParseError::Header); + from!(httparse::Error::Status => ParseError::Status); + from!(httparse::Error::Token => ParseError::Header); + from!(httparse::Error::TooManyHeaders => ParseError::TooLarge); + from!(httparse::Error::Version => ParseError::Version); + } + + #[test] + fn test_internal_error() { + let err = + InternalError::from_response(ParseError::Method, Response::Ok().into()); + let resp: Response = err.error_response(); + assert_eq!(resp.status(), StatusCode::OK); + } + + #[test] + fn test_error_helpers() { + let r: Response = ErrorBadRequest("err").into(); + assert_eq!(r.status(), StatusCode::BAD_REQUEST); + + let r: Response = ErrorUnauthorized("err").into(); + assert_eq!(r.status(), StatusCode::UNAUTHORIZED); + + let r: Response = ErrorPaymentRequired("err").into(); + assert_eq!(r.status(), StatusCode::PAYMENT_REQUIRED); + + let r: Response = ErrorForbidden("err").into(); + assert_eq!(r.status(), StatusCode::FORBIDDEN); + + let r: Response = ErrorNotFound("err").into(); + assert_eq!(r.status(), StatusCode::NOT_FOUND); + + let r: Response = ErrorMethodNotAllowed("err").into(); + assert_eq!(r.status(), StatusCode::METHOD_NOT_ALLOWED); + + let r: Response = ErrorNotAcceptable("err").into(); + assert_eq!(r.status(), StatusCode::NOT_ACCEPTABLE); + + let r: Response = ErrorProxyAuthenticationRequired("err").into(); + assert_eq!(r.status(), StatusCode::PROXY_AUTHENTICATION_REQUIRED); + + let r: Response = ErrorRequestTimeout("err").into(); + assert_eq!(r.status(), StatusCode::REQUEST_TIMEOUT); + + let r: Response = ErrorConflict("err").into(); + assert_eq!(r.status(), StatusCode::CONFLICT); + + let r: Response = ErrorGone("err").into(); + assert_eq!(r.status(), StatusCode::GONE); + + let r: Response = ErrorLengthRequired("err").into(); + assert_eq!(r.status(), StatusCode::LENGTH_REQUIRED); + + let r: Response = ErrorPreconditionFailed("err").into(); + assert_eq!(r.status(), StatusCode::PRECONDITION_FAILED); + + let r: Response = ErrorPayloadTooLarge("err").into(); + assert_eq!(r.status(), StatusCode::PAYLOAD_TOO_LARGE); + + let r: Response = ErrorUriTooLong("err").into(); + assert_eq!(r.status(), StatusCode::URI_TOO_LONG); + + let r: Response = ErrorUnsupportedMediaType("err").into(); + assert_eq!(r.status(), StatusCode::UNSUPPORTED_MEDIA_TYPE); + + let r: Response = ErrorRangeNotSatisfiable("err").into(); + assert_eq!(r.status(), StatusCode::RANGE_NOT_SATISFIABLE); + + let r: Response = ErrorExpectationFailed("err").into(); + assert_eq!(r.status(), StatusCode::EXPECTATION_FAILED); + + let r: Response = ErrorImATeapot("err").into(); + assert_eq!(r.status(), StatusCode::IM_A_TEAPOT); + + let r: Response = ErrorMisdirectedRequest("err").into(); + assert_eq!(r.status(), StatusCode::MISDIRECTED_REQUEST); + + let r: Response = ErrorUnprocessableEntity("err").into(); + assert_eq!(r.status(), StatusCode::UNPROCESSABLE_ENTITY); + + let r: Response = ErrorLocked("err").into(); + assert_eq!(r.status(), StatusCode::LOCKED); + + let r: Response = ErrorFailedDependency("err").into(); + assert_eq!(r.status(), StatusCode::FAILED_DEPENDENCY); + + let r: Response = ErrorUpgradeRequired("err").into(); + assert_eq!(r.status(), StatusCode::UPGRADE_REQUIRED); + + let r: Response = ErrorPreconditionRequired("err").into(); + assert_eq!(r.status(), StatusCode::PRECONDITION_REQUIRED); + + let r: Response = ErrorTooManyRequests("err").into(); + assert_eq!(r.status(), StatusCode::TOO_MANY_REQUESTS); + + let r: Response = ErrorRequestHeaderFieldsTooLarge("err").into(); + assert_eq!(r.status(), StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE); + + let r: Response = ErrorUnavailableForLegalReasons("err").into(); + assert_eq!(r.status(), StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS); + + let r: Response = ErrorInternalServerError("err").into(); + assert_eq!(r.status(), StatusCode::INTERNAL_SERVER_ERROR); + + let r: Response = ErrorNotImplemented("err").into(); + assert_eq!(r.status(), StatusCode::NOT_IMPLEMENTED); + + let r: Response = ErrorBadGateway("err").into(); + assert_eq!(r.status(), StatusCode::BAD_GATEWAY); + + let r: Response = ErrorServiceUnavailable("err").into(); + assert_eq!(r.status(), StatusCode::SERVICE_UNAVAILABLE); + + let r: Response = ErrorGatewayTimeout("err").into(); + assert_eq!(r.status(), StatusCode::GATEWAY_TIMEOUT); + + let r: Response = ErrorHttpVersionNotSupported("err").into(); + assert_eq!(r.status(), StatusCode::HTTP_VERSION_NOT_SUPPORTED); + + let r: Response = ErrorVariantAlsoNegotiates("err").into(); + assert_eq!(r.status(), StatusCode::VARIANT_ALSO_NEGOTIATES); + + let r: Response = ErrorInsufficientStorage("err").into(); + assert_eq!(r.status(), StatusCode::INSUFFICIENT_STORAGE); + + let r: Response = ErrorLoopDetected("err").into(); + assert_eq!(r.status(), StatusCode::LOOP_DETECTED); + + let r: Response = ErrorNotExtended("err").into(); + assert_eq!(r.status(), StatusCode::NOT_EXTENDED); + + let r: Response = ErrorNetworkAuthenticationRequired("err").into(); + assert_eq!(r.status(), StatusCode::NETWORK_AUTHENTICATION_REQUIRED); + } +} diff --git a/actix-http/src/extensions.rs b/actix-http/src/extensions.rs new file mode 100644 index 000000000..148e4c18e --- /dev/null +++ b/actix-http/src/extensions.rs @@ -0,0 +1,91 @@ +use std::any::{Any, TypeId}; +use std::fmt; + +use hashbrown::HashMap; + +#[derive(Default)] +/// A type map of request extensions. +pub struct Extensions { + map: HashMap>, +} + +impl Extensions { + /// Create an empty `Extensions`. + #[inline] + pub fn new() -> Extensions { + Extensions { + map: HashMap::default(), + } + } + + /// Insert a type into this `Extensions`. + /// + /// If a extension of this type already existed, it will + /// be returned. + pub fn insert(&mut self, val: T) { + self.map.insert(TypeId::of::(), Box::new(val)); + } + + /// Check if container contains entry + pub fn contains(&self) -> bool { + self.map.get(&TypeId::of::()).is_some() + } + + /// Get a reference to a type previously inserted on this `Extensions`. + pub fn get(&self) -> Option<&T> { + self.map + .get(&TypeId::of::()) + .and_then(|boxed| (&**boxed as &(Any + 'static)).downcast_ref()) + } + + /// Get a mutable reference to a type previously inserted on this `Extensions`. + pub fn get_mut(&mut self) -> Option<&mut T> { + self.map + .get_mut(&TypeId::of::()) + .and_then(|boxed| (&mut **boxed as &mut (Any + 'static)).downcast_mut()) + } + + /// Remove a type from this `Extensions`. + /// + /// If a extension of this type existed, it will be returned. + pub fn remove(&mut self) -> Option { + self.map.remove(&TypeId::of::()).and_then(|boxed| { + (boxed as Box) + .downcast() + .ok() + .map(|boxed| *boxed) + }) + } + + /// Clear the `Extensions` of all inserted extensions. + #[inline] + pub fn clear(&mut self) { + self.map.clear(); + } +} + +impl fmt::Debug for Extensions { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("Extensions").finish() + } +} + +#[test] +fn test_extensions() { + #[derive(Debug, PartialEq)] + struct MyType(i32); + + let mut extensions = Extensions::new(); + + extensions.insert(5i32); + extensions.insert(MyType(10)); + + assert_eq!(extensions.get(), Some(&5i32)); + assert_eq!(extensions.get_mut(), Some(&mut 5i32)); + + assert_eq!(extensions.remove::(), Some(5i32)); + assert!(extensions.get::().is_none()); + + assert_eq!(extensions.get::(), None); + assert_eq!(extensions.get(), Some(&MyType(10))); +} diff --git a/actix-http/src/h1/client.rs b/actix-http/src/h1/client.rs new file mode 100644 index 000000000..b3a5a5d9f --- /dev/null +++ b/actix-http/src/h1/client.rs @@ -0,0 +1,241 @@ +#![allow(unused_imports, unused_variables, dead_code)] +use std::io::{self, Write}; + +use actix_codec::{Decoder, Encoder}; +use bitflags::bitflags; +use bytes::{BufMut, Bytes, BytesMut}; +use http::header::{ + HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING, UPGRADE, +}; +use http::{Method, Version}; + +use super::decoder::{PayloadDecoder, PayloadItem, PayloadType}; +use super::{decoder, encoder}; +use super::{Message, MessageType}; +use crate::body::BodyLength; +use crate::config::ServiceConfig; +use crate::error::{ParseError, PayloadError}; +use crate::helpers; +use crate::message::{ConnectionType, Head, MessagePool, RequestHead, ResponseHead}; + +bitflags! { + struct Flags: u8 { + const HEAD = 0b0000_0001; + const KEEPALIVE_ENABLED = 0b0000_1000; + const STREAM = 0b0001_0000; + } +} + +const AVERAGE_HEADER_SIZE: usize = 30; + +/// HTTP/1 Codec +pub struct ClientCodec { + inner: ClientCodecInner, +} + +/// HTTP/1 Payload Codec +pub struct ClientPayloadCodec { + inner: ClientCodecInner, +} + +struct ClientCodecInner { + config: ServiceConfig, + decoder: decoder::MessageDecoder, + payload: Option, + version: Version, + ctype: ConnectionType, + + // encoder part + flags: Flags, + headers_size: u32, + encoder: encoder::MessageEncoder, +} + +impl Default for ClientCodec { + fn default() -> Self { + ClientCodec::new(ServiceConfig::default()) + } +} + +impl ClientCodec { + /// Create HTTP/1 codec. + /// + /// `keepalive_enabled` how response `connection` header get generated. + pub fn new(config: ServiceConfig) -> Self { + let flags = if config.keep_alive_enabled() { + Flags::KEEPALIVE_ENABLED + } else { + Flags::empty() + }; + ClientCodec { + inner: ClientCodecInner { + config, + decoder: decoder::MessageDecoder::default(), + payload: None, + version: Version::HTTP_11, + ctype: ConnectionType::Close, + + flags, + headers_size: 0, + encoder: encoder::MessageEncoder::default(), + }, + } + } + + /// Check if request is upgrade + pub fn upgrade(&self) -> bool { + self.inner.ctype == ConnectionType::Upgrade + } + + /// Check if last response is keep-alive + pub fn keepalive(&self) -> bool { + self.inner.ctype == ConnectionType::KeepAlive + } + + /// Check last request's message type + pub fn message_type(&self) -> MessageType { + if self.inner.flags.contains(Flags::STREAM) { + MessageType::Stream + } else if self.inner.payload.is_none() { + MessageType::None + } else { + MessageType::Payload + } + } + + /// Convert message codec to a payload codec + pub fn into_payload_codec(self) -> ClientPayloadCodec { + ClientPayloadCodec { inner: self.inner } + } +} + +impl ClientPayloadCodec { + /// Check if last response is keep-alive + pub fn keepalive(&self) -> bool { + self.inner.ctype == ConnectionType::KeepAlive + } + + /// Transform payload codec to a message codec + pub fn into_message_codec(self) -> ClientCodec { + ClientCodec { inner: self.inner } + } +} + +impl Decoder for ClientCodec { + type Item = ResponseHead; + type Error = ParseError; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + debug_assert!(!self.inner.payload.is_some(), "Payload decoder is set"); + + if let Some((req, payload)) = self.inner.decoder.decode(src)? { + if let Some(ctype) = req.ctype { + // do not use peer's keep-alive + self.inner.ctype = if ctype == ConnectionType::KeepAlive { + self.inner.ctype + } else { + ctype + }; + } + + if !self.inner.flags.contains(Flags::HEAD) { + match payload { + PayloadType::None => self.inner.payload = None, + PayloadType::Payload(pl) => self.inner.payload = Some(pl), + PayloadType::Stream(pl) => { + self.inner.payload = Some(pl); + self.inner.flags.insert(Flags::STREAM); + } + } + } else { + self.inner.payload = None; + } + Ok(Some(req)) + } else { + Ok(None) + } + } +} + +impl Decoder for ClientPayloadCodec { + type Item = Option; + type Error = PayloadError; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + debug_assert!( + self.inner.payload.is_some(), + "Payload decoder is not specified" + ); + + Ok(match self.inner.payload.as_mut().unwrap().decode(src)? { + Some(PayloadItem::Chunk(chunk)) => Some(Some(chunk)), + Some(PayloadItem::Eof) => { + self.inner.payload.take(); + Some(None) + } + None => None, + }) + } +} + +impl Encoder for ClientCodec { + type Item = Message<(RequestHead, BodyLength)>; + type Error = io::Error; + + fn encode( + &mut self, + item: Self::Item, + dst: &mut BytesMut, + ) -> Result<(), Self::Error> { + match item { + Message::Item((mut msg, length)) => { + let inner = &mut self.inner; + inner.version = msg.version; + inner.flags.set(Flags::HEAD, msg.method == Method::HEAD); + + // connection status + inner.ctype = match msg.connection_type() { + ConnectionType::KeepAlive => { + if inner.flags.contains(Flags::KEEPALIVE_ENABLED) { + ConnectionType::KeepAlive + } else { + ConnectionType::Close + } + } + ConnectionType::Upgrade => ConnectionType::Upgrade, + ConnectionType::Close => ConnectionType::Close, + }; + + inner.encoder.encode( + dst, + &mut msg, + false, + false, + inner.version, + length, + inner.ctype, + &inner.config, + )?; + } + Message::Chunk(Some(bytes)) => { + self.inner.encoder.encode_chunk(bytes.as_ref(), dst)?; + } + Message::Chunk(None) => { + self.inner.encoder.encode_eof(dst)?; + } + } + Ok(()) + } +} + +pub struct Writer<'a>(pub &'a mut BytesMut); + +impl<'a> io::Write for Writer<'a> { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.0.extend_from_slice(buf); + Ok(buf.len()) + } + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} diff --git a/actix-http/src/h1/codec.rs b/actix-http/src/h1/codec.rs new file mode 100644 index 000000000..c66364c02 --- /dev/null +++ b/actix-http/src/h1/codec.rs @@ -0,0 +1,242 @@ +#![allow(unused_imports, unused_variables, dead_code)] +use std::fmt; +use std::io::{self, Write}; + +use actix_codec::{Decoder, Encoder}; +use bitflags::bitflags; +use bytes::{BufMut, Bytes, BytesMut}; +use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING}; +use http::{Method, StatusCode, Version}; + +use super::decoder::{PayloadDecoder, PayloadItem, PayloadType}; +use super::{decoder, encoder}; +use super::{Message, MessageType}; +use crate::body::BodyLength; +use crate::config::ServiceConfig; +use crate::error::ParseError; +use crate::helpers; +use crate::message::{ConnectionType, Head, ResponseHead}; +use crate::request::Request; +use crate::response::Response; + +bitflags! { + struct Flags: u8 { + const HEAD = 0b0000_0001; + const KEEPALIVE_ENABLED = 0b0000_1000; + const STREAM = 0b0001_0000; + } +} + +const AVERAGE_HEADER_SIZE: usize = 30; + +/// HTTP/1 Codec +pub struct Codec { + config: ServiceConfig, + decoder: decoder::MessageDecoder, + payload: Option, + version: Version, + ctype: ConnectionType, + + // encoder part + flags: Flags, + headers_size: u32, + encoder: encoder::MessageEncoder>, +} + +impl Default for Codec { + fn default() -> Self { + Codec::new(ServiceConfig::default()) + } +} + +impl fmt::Debug for Codec { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "h1::Codec({:?})", self.flags) + } +} + +impl Codec { + /// Create HTTP/1 codec. + /// + /// `keepalive_enabled` how response `connection` header get generated. + pub fn new(config: ServiceConfig) -> Self { + let flags = if config.keep_alive_enabled() { + Flags::KEEPALIVE_ENABLED + } else { + Flags::empty() + }; + Codec { + config, + decoder: decoder::MessageDecoder::default(), + payload: None, + version: Version::HTTP_11, + ctype: ConnectionType::Close, + + flags, + headers_size: 0, + encoder: encoder::MessageEncoder::default(), + } + } + + /// Check if request is upgrade + pub fn upgrade(&self) -> bool { + self.ctype == ConnectionType::Upgrade + } + + /// Check if last response is keep-alive + pub fn keepalive(&self) -> bool { + self.ctype == ConnectionType::KeepAlive + } + + /// Check last request's message type + pub fn message_type(&self) -> MessageType { + if self.flags.contains(Flags::STREAM) { + MessageType::Stream + } else if self.payload.is_none() { + MessageType::None + } else { + MessageType::Payload + } + } +} + +impl Decoder for Codec { + type Item = Message; + type Error = ParseError; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + if self.payload.is_some() { + Ok(match self.payload.as_mut().unwrap().decode(src)? { + Some(PayloadItem::Chunk(chunk)) => Some(Message::Chunk(Some(chunk))), + Some(PayloadItem::Eof) => { + self.payload.take(); + Some(Message::Chunk(None)) + } + None => None, + }) + } else if let Some((req, payload)) = self.decoder.decode(src)? { + let head = req.head(); + self.flags.set(Flags::HEAD, head.method == Method::HEAD); + self.version = head.version; + self.ctype = head.connection_type(); + if self.ctype == ConnectionType::KeepAlive + && !self.flags.contains(Flags::KEEPALIVE_ENABLED) + { + self.ctype = ConnectionType::Close + } + match payload { + PayloadType::None => self.payload = None, + PayloadType::Payload(pl) => self.payload = Some(pl), + PayloadType::Stream(pl) => { + self.payload = Some(pl); + self.flags.insert(Flags::STREAM); + } + } + Ok(Some(Message::Item(req))) + } else { + Ok(None) + } + } +} + +impl Encoder for Codec { + type Item = Message<(Response<()>, BodyLength)>; + type Error = io::Error; + + fn encode( + &mut self, + item: Self::Item, + dst: &mut BytesMut, + ) -> Result<(), Self::Error> { + match item { + Message::Item((mut res, length)) => { + // set response version + res.head_mut().version = self.version; + + // connection status + self.ctype = if let Some(ct) = res.head().ctype { + if ct == ConnectionType::KeepAlive { + self.ctype + } else { + ct + } + } else { + self.ctype + }; + + // encode message + let len = dst.len(); + self.encoder.encode( + dst, + &mut res, + self.flags.contains(Flags::HEAD), + self.flags.contains(Flags::STREAM), + self.version, + length, + self.ctype, + &self.config, + )?; + self.headers_size = (dst.len() - len) as u32; + } + Message::Chunk(Some(bytes)) => { + self.encoder.encode_chunk(bytes.as_ref(), dst)?; + } + Message::Chunk(None) => { + self.encoder.encode_eof(dst)?; + } + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use std::{cmp, io}; + + use actix_codec::{AsyncRead, AsyncWrite}; + use bytes::{Buf, Bytes, BytesMut}; + use http::{Method, Version}; + + use super::*; + use crate::error::ParseError; + use crate::h1::Message; + use crate::httpmessage::HttpMessage; + use crate::request::Request; + + #[test] + fn test_http_request_chunked_payload_and_next_message() { + let mut codec = Codec::default(); + + let mut buf = BytesMut::from( + "GET /test HTTP/1.1\r\n\ + transfer-encoding: chunked\r\n\r\n", + ); + let item = codec.decode(&mut buf).unwrap().unwrap(); + let req = item.message(); + + assert_eq!(req.method(), Method::GET); + assert!(req.chunked().unwrap()); + + buf.extend( + b"4\r\ndata\r\n4\r\nline\r\n0\r\n\r\n\ + POST /test2 HTTP/1.1\r\n\ + transfer-encoding: chunked\r\n\r\n" + .iter(), + ); + + let msg = codec.decode(&mut buf).unwrap().unwrap(); + assert_eq!(msg.chunk().as_ref(), b"data"); + + let msg = codec.decode(&mut buf).unwrap().unwrap(); + assert_eq!(msg.chunk().as_ref(), b"line"); + + let msg = codec.decode(&mut buf).unwrap().unwrap(); + assert!(msg.eof()); + + // decode next message + let item = codec.decode(&mut buf).unwrap().unwrap(); + let req = item.message(); + assert_eq!(*req.method(), Method::POST); + assert!(req.chunked().unwrap()); + } +} diff --git a/actix-http/src/h1/decoder.rs b/actix-http/src/h1/decoder.rs new file mode 100644 index 000000000..a1b221c06 --- /dev/null +++ b/actix-http/src/h1/decoder.rs @@ -0,0 +1,1170 @@ +use std::marker::PhantomData; +use std::{io, mem}; + +use actix_codec::Decoder; +use bytes::{Bytes, BytesMut}; +use futures::{Async, Poll}; +use http::header::{HeaderName, HeaderValue}; +use http::{header, HeaderMap, HttpTryFrom, Method, StatusCode, Uri, Version}; +use httparse; +use log::{debug, error, trace}; + +use crate::error::ParseError; +use crate::message::{ConnectionType, ResponseHead}; +use crate::request::Request; + +const MAX_BUFFER_SIZE: usize = 131_072; +const MAX_HEADERS: usize = 96; + +/// Incoming messagd decoder +pub(crate) struct MessageDecoder(PhantomData); + +#[derive(Debug)] +/// Incoming request type +pub(crate) enum PayloadType { + None, + Payload(PayloadDecoder), + Stream(PayloadDecoder), +} + +impl Default for MessageDecoder { + fn default() -> Self { + MessageDecoder(PhantomData) + } +} + +impl Decoder for MessageDecoder { + type Item = (T, PayloadType); + type Error = ParseError; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + T::decode(src) + } +} + +pub(crate) enum PayloadLength { + Payload(PayloadType), + Upgrade, + None, +} + +pub(crate) trait MessageType: Sized { + fn set_connection_type(&mut self, ctype: Option); + + fn headers_mut(&mut self) -> &mut HeaderMap; + + fn decode(src: &mut BytesMut) -> Result, ParseError>; + + fn set_headers( + &mut self, + slice: &Bytes, + raw_headers: &[HeaderIndex], + ) -> Result { + let mut ka = None; + let mut has_upgrade = false; + let mut chunked = false; + let mut content_length = None; + + { + let headers = self.headers_mut(); + + for idx in raw_headers.iter() { + if let Ok(name) = HeaderName::from_bytes(&slice[idx.name.0..idx.name.1]) + { + // Unsafe: httparse check header value for valid utf-8 + let value = unsafe { + HeaderValue::from_shared_unchecked( + slice.slice(idx.value.0, idx.value.1), + ) + }; + match name { + header::CONTENT_LENGTH => { + if let Ok(s) = value.to_str() { + if let Ok(len) = s.parse::() { + content_length = Some(len); + } else { + debug!("illegal Content-Length: {:?}", s); + return Err(ParseError::Header); + } + } else { + debug!("illegal Content-Length: {:?}", value); + return Err(ParseError::Header); + } + } + // transfer-encoding + header::TRANSFER_ENCODING => { + if let Ok(s) = value.to_str().map(|s| s.trim()) { + chunked = s.eq_ignore_ascii_case("chunked"); + } else { + return Err(ParseError::Header); + } + } + // connection keep-alive state + header::CONNECTION => { + ka = if let Ok(conn) = value.to_str().map(|conn| conn.trim()) + { + if conn.eq_ignore_ascii_case("keep-alive") { + Some(ConnectionType::KeepAlive) + } else if conn.eq_ignore_ascii_case("close") { + Some(ConnectionType::Close) + } else if conn.eq_ignore_ascii_case("upgrade") { + Some(ConnectionType::Upgrade) + } else { + None + } + } else { + None + }; + } + header::UPGRADE => { + has_upgrade = true; + // check content-length, some clients (dart) + // sends "content-length: 0" with websocket upgrade + if let Ok(val) = value.to_str().map(|val| val.trim()) { + if val.eq_ignore_ascii_case("websocket") { + content_length = None; + } + } + } + _ => (), + } + + headers.append(name, value); + } else { + return Err(ParseError::Header); + } + } + } + self.set_connection_type(ka); + + // https://tools.ietf.org/html/rfc7230#section-3.3.3 + if chunked { + // Chunked encoding + Ok(PayloadLength::Payload(PayloadType::Payload( + PayloadDecoder::chunked(), + ))) + } else if let Some(len) = content_length { + // Content-Length + Ok(PayloadLength::Payload(PayloadType::Payload( + PayloadDecoder::length(len), + ))) + } else if has_upgrade { + Ok(PayloadLength::Upgrade) + } else { + Ok(PayloadLength::None) + } + } +} + +impl MessageType for Request { + fn set_connection_type(&mut self, ctype: Option) { + self.head_mut().ctype = ctype; + } + + fn headers_mut(&mut self) -> &mut HeaderMap { + &mut self.head_mut().headers + } + + fn decode(src: &mut BytesMut) -> Result, ParseError> { + // Unsafe: we read only this data only after httparse parses headers into. + // performance bump for pipeline benchmarks. + let mut headers: [HeaderIndex; MAX_HEADERS] = unsafe { mem::uninitialized() }; + + let (len, method, uri, ver, h_len) = { + let mut parsed: [httparse::Header; MAX_HEADERS] = + unsafe { mem::uninitialized() }; + + let mut req = httparse::Request::new(&mut parsed); + match req.parse(src)? { + httparse::Status::Complete(len) => { + let method = Method::from_bytes(req.method.unwrap().as_bytes()) + .map_err(|_| ParseError::Method)?; + let uri = Uri::try_from(req.path.unwrap())?; + let version = if req.version.unwrap() == 1 { + Version::HTTP_11 + } else { + Version::HTTP_10 + }; + HeaderIndex::record(src, req.headers, &mut headers); + + (len, method, uri, version, req.headers.len()) + } + httparse::Status::Partial => return Ok(None), + } + }; + + let mut msg = Request::new(); + + // convert headers + let len = msg.set_headers(&src.split_to(len).freeze(), &headers[..h_len])?; + + // payload decoder + let decoder = match len { + PayloadLength::Payload(pl) => pl, + PayloadLength::Upgrade => { + // upgrade(websocket) + PayloadType::Stream(PayloadDecoder::eof()) + } + PayloadLength::None => { + if method == Method::CONNECT { + PayloadType::Stream(PayloadDecoder::eof()) + } else if src.len() >= MAX_BUFFER_SIZE { + trace!("MAX_BUFFER_SIZE unprocessed data reached, closing"); + return Err(ParseError::TooLarge); + } else { + PayloadType::None + } + } + }; + + let head = msg.head_mut(); + head.uri = uri; + head.method = method; + head.version = ver; + + Ok(Some((msg, decoder))) + } +} + +impl MessageType for ResponseHead { + fn set_connection_type(&mut self, ctype: Option) { + self.ctype = ctype; + } + + fn headers_mut(&mut self) -> &mut HeaderMap { + &mut self.headers + } + + fn decode(src: &mut BytesMut) -> Result, ParseError> { + // Unsafe: we read only this data only after httparse parses headers into. + // performance bump for pipeline benchmarks. + let mut headers: [HeaderIndex; MAX_HEADERS] = unsafe { mem::uninitialized() }; + + let (len, ver, status, h_len) = { + let mut parsed: [httparse::Header; MAX_HEADERS] = + unsafe { mem::uninitialized() }; + + let mut res = httparse::Response::new(&mut parsed); + match res.parse(src)? { + httparse::Status::Complete(len) => { + let version = if res.version.unwrap() == 1 { + Version::HTTP_11 + } else { + Version::HTTP_10 + }; + let status = StatusCode::from_u16(res.code.unwrap()) + .map_err(|_| ParseError::Status)?; + HeaderIndex::record(src, res.headers, &mut headers); + + (len, version, status, res.headers.len()) + } + httparse::Status::Partial => return Ok(None), + } + }; + + let mut msg = ResponseHead::default(); + + // convert headers + let len = msg.set_headers(&src.split_to(len).freeze(), &headers[..h_len])?; + + // message payload + let decoder = if let PayloadLength::Payload(pl) = len { + pl + } else if status == StatusCode::SWITCHING_PROTOCOLS { + // switching protocol or connect + PayloadType::Stream(PayloadDecoder::eof()) + } else if src.len() >= MAX_BUFFER_SIZE { + error!("MAX_BUFFER_SIZE unprocessed data reached, closing"); + return Err(ParseError::TooLarge); + } else { + PayloadType::None + }; + + msg.status = status; + msg.version = ver; + + Ok(Some((msg, decoder))) + } +} + +#[derive(Clone, Copy)] +pub(crate) struct HeaderIndex { + pub(crate) name: (usize, usize), + pub(crate) value: (usize, usize), +} + +impl HeaderIndex { + pub(crate) fn record( + bytes: &[u8], + headers: &[httparse::Header], + indices: &mut [HeaderIndex], + ) { + let bytes_ptr = bytes.as_ptr() as usize; + for (header, indices) in headers.iter().zip(indices.iter_mut()) { + let name_start = header.name.as_ptr() as usize - bytes_ptr; + let name_end = name_start + header.name.len(); + indices.name = (name_start, name_end); + let value_start = header.value.as_ptr() as usize - bytes_ptr; + let value_end = value_start + header.value.len(); + indices.value = (value_start, value_end); + } + } +} + +#[derive(Debug, Clone)] +/// Http payload item +pub enum PayloadItem { + Chunk(Bytes), + Eof, +} + +/// Decoders to handle different Transfer-Encodings. +/// +/// If a message body does not include a Transfer-Encoding, it *should* +/// include a Content-Length header. +#[derive(Debug, Clone, PartialEq)] +pub struct PayloadDecoder { + kind: Kind, +} + +impl PayloadDecoder { + pub fn length(x: u64) -> PayloadDecoder { + PayloadDecoder { + kind: Kind::Length(x), + } + } + + pub fn chunked() -> PayloadDecoder { + PayloadDecoder { + kind: Kind::Chunked(ChunkedState::Size, 0), + } + } + + pub fn eof() -> PayloadDecoder { + PayloadDecoder { kind: Kind::Eof } + } +} + +#[derive(Debug, Clone, PartialEq)] +enum Kind { + /// A Reader used when a Content-Length header is passed with a positive + /// integer. + Length(u64), + /// A Reader used when Transfer-Encoding is `chunked`. + Chunked(ChunkedState, u64), + /// A Reader used for responses that don't indicate a length or chunked. + /// + /// Note: This should only used for `Response`s. It is illegal for a + /// `Request` to be made with both `Content-Length` and + /// `Transfer-Encoding: chunked` missing, as explained from the spec: + /// + /// > If a Transfer-Encoding header field is present in a response and + /// > the chunked transfer coding is not the final encoding, the + /// > message body length is determined by reading the connection until + /// > it is closed by the server. If a Transfer-Encoding header field + /// > is present in a request and the chunked transfer coding is not + /// > the final encoding, the message body length cannot be determined + /// > reliably; the server MUST respond with the 400 (Bad Request) + /// > status code and then close the connection. + Eof, +} + +#[derive(Debug, PartialEq, Clone)] +enum ChunkedState { + Size, + SizeLws, + Extension, + SizeLf, + Body, + BodyCr, + BodyLf, + EndCr, + EndLf, + End, +} + +impl Decoder for PayloadDecoder { + type Item = PayloadItem; + type Error = io::Error; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + match self.kind { + Kind::Length(ref mut remaining) => { + if *remaining == 0 { + Ok(Some(PayloadItem::Eof)) + } else { + if src.is_empty() { + return Ok(None); + } + let len = src.len() as u64; + let buf; + if *remaining > len { + buf = src.take().freeze(); + *remaining -= len; + } else { + buf = src.split_to(*remaining as usize).freeze(); + *remaining = 0; + }; + trace!("Length read: {}", buf.len()); + Ok(Some(PayloadItem::Chunk(buf))) + } + } + Kind::Chunked(ref mut state, ref mut size) => { + loop { + let mut buf = None; + // advances the chunked state + *state = match state.step(src, size, &mut buf)? { + Async::NotReady => return Ok(None), + Async::Ready(state) => state, + }; + if *state == ChunkedState::End { + trace!("End of chunked stream"); + return Ok(Some(PayloadItem::Eof)); + } + if let Some(buf) = buf { + return Ok(Some(PayloadItem::Chunk(buf))); + } + if src.is_empty() { + return Ok(None); + } + } + } + Kind::Eof => { + if src.is_empty() { + Ok(None) + } else { + Ok(Some(PayloadItem::Chunk(src.take().freeze()))) + } + } + } + } +} + +macro_rules! byte ( + ($rdr:ident) => ({ + if $rdr.len() > 0 { + let b = $rdr[0]; + $rdr.split_to(1); + b + } else { + return Ok(Async::NotReady) + } + }) +); + +impl ChunkedState { + fn step( + &self, + body: &mut BytesMut, + size: &mut u64, + buf: &mut Option, + ) -> Poll { + use self::ChunkedState::*; + match *self { + Size => ChunkedState::read_size(body, size), + SizeLws => ChunkedState::read_size_lws(body), + Extension => ChunkedState::read_extension(body), + SizeLf => ChunkedState::read_size_lf(body, size), + Body => ChunkedState::read_body(body, size, buf), + BodyCr => ChunkedState::read_body_cr(body), + BodyLf => ChunkedState::read_body_lf(body), + EndCr => ChunkedState::read_end_cr(body), + EndLf => ChunkedState::read_end_lf(body), + End => Ok(Async::Ready(ChunkedState::End)), + } + } + fn read_size(rdr: &mut BytesMut, size: &mut u64) -> Poll { + let radix = 16; + match byte!(rdr) { + b @ b'0'...b'9' => { + *size *= radix; + *size += u64::from(b - b'0'); + } + b @ b'a'...b'f' => { + *size *= radix; + *size += u64::from(b + 10 - b'a'); + } + b @ b'A'...b'F' => { + *size *= radix; + *size += u64::from(b + 10 - b'A'); + } + b'\t' | b' ' => return Ok(Async::Ready(ChunkedState::SizeLws)), + b';' => return Ok(Async::Ready(ChunkedState::Extension)), + b'\r' => return Ok(Async::Ready(ChunkedState::SizeLf)), + _ => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk size line: Invalid Size", + )); + } + } + Ok(Async::Ready(ChunkedState::Size)) + } + fn read_size_lws(rdr: &mut BytesMut) -> Poll { + trace!("read_size_lws"); + match byte!(rdr) { + // LWS can follow the chunk size, but no more digits can come + b'\t' | b' ' => Ok(Async::Ready(ChunkedState::SizeLws)), + b';' => Ok(Async::Ready(ChunkedState::Extension)), + b'\r' => Ok(Async::Ready(ChunkedState::SizeLf)), + _ => Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk size linear white space", + )), + } + } + fn read_extension(rdr: &mut BytesMut) -> Poll { + match byte!(rdr) { + b'\r' => Ok(Async::Ready(ChunkedState::SizeLf)), + _ => Ok(Async::Ready(ChunkedState::Extension)), // no supported extensions + } + } + fn read_size_lf( + rdr: &mut BytesMut, + size: &mut u64, + ) -> Poll { + match byte!(rdr) { + b'\n' if *size > 0 => Ok(Async::Ready(ChunkedState::Body)), + b'\n' if *size == 0 => Ok(Async::Ready(ChunkedState::EndCr)), + _ => Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk size LF", + )), + } + } + + fn read_body( + rdr: &mut BytesMut, + rem: &mut u64, + buf: &mut Option, + ) -> Poll { + trace!("Chunked read, remaining={:?}", rem); + + let len = rdr.len() as u64; + if len == 0 { + Ok(Async::Ready(ChunkedState::Body)) + } else { + let slice; + if *rem > len { + slice = rdr.take().freeze(); + *rem -= len; + } else { + slice = rdr.split_to(*rem as usize).freeze(); + *rem = 0; + } + *buf = Some(slice); + if *rem > 0 { + Ok(Async::Ready(ChunkedState::Body)) + } else { + Ok(Async::Ready(ChunkedState::BodyCr)) + } + } + } + + fn read_body_cr(rdr: &mut BytesMut) -> Poll { + match byte!(rdr) { + b'\r' => Ok(Async::Ready(ChunkedState::BodyLf)), + _ => Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk body CR", + )), + } + } + fn read_body_lf(rdr: &mut BytesMut) -> Poll { + match byte!(rdr) { + b'\n' => Ok(Async::Ready(ChunkedState::Size)), + _ => Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk body LF", + )), + } + } + fn read_end_cr(rdr: &mut BytesMut) -> Poll { + match byte!(rdr) { + b'\r' => Ok(Async::Ready(ChunkedState::EndLf)), + _ => Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk end CR", + )), + } + } + fn read_end_lf(rdr: &mut BytesMut) -> Poll { + match byte!(rdr) { + b'\n' => Ok(Async::Ready(ChunkedState::End)), + _ => Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Invalid chunk end LF", + )), + } + } +} + +#[cfg(test)] +mod tests { + use bytes::{Bytes, BytesMut}; + use http::{Method, Version}; + + use super::*; + use crate::error::ParseError; + use crate::httpmessage::HttpMessage; + use crate::message::Head; + + impl PayloadType { + fn unwrap(self) -> PayloadDecoder { + match self { + PayloadType::Payload(pl) => pl, + _ => panic!(), + } + } + + fn is_unhandled(&self) -> bool { + match self { + PayloadType::Stream(_) => true, + _ => false, + } + } + } + + impl PayloadItem { + fn chunk(self) -> Bytes { + match self { + PayloadItem::Chunk(chunk) => chunk, + _ => panic!("error"), + } + } + fn eof(&self) -> bool { + match *self { + PayloadItem::Eof => true, + _ => false, + } + } + } + + macro_rules! parse_ready { + ($e:expr) => {{ + match MessageDecoder::::default().decode($e) { + Ok(Some((msg, _))) => msg, + Ok(_) => unreachable!("Eof during parsing http request"), + Err(err) => unreachable!("Error during parsing http request: {:?}", err), + } + }}; + } + + macro_rules! expect_parse_err { + ($e:expr) => {{ + match MessageDecoder::::default().decode($e) { + Err(err) => match err { + ParseError::Io(_) => unreachable!("Parse error expected"), + _ => (), + }, + _ => unreachable!("Error expected"), + } + }}; + } + + #[test] + fn test_parse() { + let mut buf = BytesMut::from("GET /test HTTP/1.1\r\n\r\n"); + + let mut reader = MessageDecoder::::default(); + match reader.decode(&mut buf) { + Ok(Some((req, _))) => { + assert_eq!(req.version(), Version::HTTP_11); + assert_eq!(*req.method(), Method::GET); + assert_eq!(req.path(), "/test"); + } + Ok(_) | Err(_) => unreachable!("Error during parsing http request"), + } + } + + #[test] + fn test_parse_partial() { + let mut buf = BytesMut::from("PUT /test HTTP/1"); + + let mut reader = MessageDecoder::::default(); + assert!(reader.decode(&mut buf).unwrap().is_none()); + + buf.extend(b".1\r\n\r\n"); + let (req, _) = reader.decode(&mut buf).unwrap().unwrap(); + assert_eq!(req.version(), Version::HTTP_11); + assert_eq!(*req.method(), Method::PUT); + assert_eq!(req.path(), "/test"); + } + + #[test] + fn test_parse_post() { + let mut buf = BytesMut::from("POST /test2 HTTP/1.0\r\n\r\n"); + + let mut reader = MessageDecoder::::default(); + let (req, _) = reader.decode(&mut buf).unwrap().unwrap(); + assert_eq!(req.version(), Version::HTTP_10); + assert_eq!(*req.method(), Method::POST); + assert_eq!(req.path(), "/test2"); + } + + #[test] + fn test_parse_body() { + let mut buf = + BytesMut::from("GET /test HTTP/1.1\r\nContent-Length: 4\r\n\r\nbody"); + + let mut reader = MessageDecoder::::default(); + let (req, pl) = reader.decode(&mut buf).unwrap().unwrap(); + let mut pl = pl.unwrap(); + assert_eq!(req.version(), Version::HTTP_11); + assert_eq!(*req.method(), Method::GET); + assert_eq!(req.path(), "/test"); + assert_eq!( + pl.decode(&mut buf).unwrap().unwrap().chunk().as_ref(), + b"body" + ); + } + + #[test] + fn test_parse_body_crlf() { + let mut buf = + BytesMut::from("\r\nGET /test HTTP/1.1\r\nContent-Length: 4\r\n\r\nbody"); + + let mut reader = MessageDecoder::::default(); + let (req, pl) = reader.decode(&mut buf).unwrap().unwrap(); + let mut pl = pl.unwrap(); + assert_eq!(req.version(), Version::HTTP_11); + assert_eq!(*req.method(), Method::GET); + assert_eq!(req.path(), "/test"); + assert_eq!( + pl.decode(&mut buf).unwrap().unwrap().chunk().as_ref(), + b"body" + ); + } + + #[test] + fn test_parse_partial_eof() { + let mut buf = BytesMut::from("GET /test HTTP/1.1\r\n"); + let mut reader = MessageDecoder::::default(); + assert!(reader.decode(&mut buf).unwrap().is_none()); + + buf.extend(b"\r\n"); + let (req, _) = reader.decode(&mut buf).unwrap().unwrap(); + assert_eq!(req.version(), Version::HTTP_11); + assert_eq!(*req.method(), Method::GET); + assert_eq!(req.path(), "/test"); + } + + #[test] + fn test_headers_split_field() { + let mut buf = BytesMut::from("GET /test HTTP/1.1\r\n"); + + let mut reader = MessageDecoder::::default(); + assert! { reader.decode(&mut buf).unwrap().is_none() } + + buf.extend(b"t"); + assert! { reader.decode(&mut buf).unwrap().is_none() } + + buf.extend(b"es"); + assert! { reader.decode(&mut buf).unwrap().is_none() } + + buf.extend(b"t: value\r\n\r\n"); + let (req, _) = reader.decode(&mut buf).unwrap().unwrap(); + assert_eq!(req.version(), Version::HTTP_11); + assert_eq!(*req.method(), Method::GET); + assert_eq!(req.path(), "/test"); + assert_eq!(req.headers().get("test").unwrap().as_bytes(), b"value"); + } + + #[test] + fn test_headers_multi_value() { + let mut buf = BytesMut::from( + "GET /test HTTP/1.1\r\n\ + Set-Cookie: c1=cookie1\r\n\ + Set-Cookie: c2=cookie2\r\n\r\n", + ); + let mut reader = MessageDecoder::::default(); + let (req, _) = reader.decode(&mut buf).unwrap().unwrap(); + + let val: Vec<_> = req + .headers() + .get_all("Set-Cookie") + .iter() + .map(|v| v.to_str().unwrap().to_owned()) + .collect(); + assert_eq!(val[0], "c1=cookie1"); + assert_eq!(val[1], "c2=cookie2"); + } + + #[test] + fn test_conn_default_1_0() { + let mut buf = BytesMut::from("GET /test HTTP/1.0\r\n\r\n"); + let req = parse_ready!(&mut buf); + + assert_eq!(req.head().connection_type(), ConnectionType::Close); + } + + #[test] + fn test_conn_default_1_1() { + let mut buf = BytesMut::from("GET /test HTTP/1.1\r\n\r\n"); + let req = parse_ready!(&mut buf); + + assert_eq!(req.head().connection_type(), ConnectionType::KeepAlive); + } + + #[test] + fn test_conn_close() { + let mut buf = BytesMut::from( + "GET /test HTTP/1.1\r\n\ + connection: close\r\n\r\n", + ); + let req = parse_ready!(&mut buf); + + assert_eq!(req.head().ctype, Some(ConnectionType::Close)); + + let mut buf = BytesMut::from( + "GET /test HTTP/1.1\r\n\ + connection: Close\r\n\r\n", + ); + let req = parse_ready!(&mut buf); + + assert_eq!(req.head().ctype, Some(ConnectionType::Close)); + } + + #[test] + fn test_conn_close_1_0() { + let mut buf = BytesMut::from( + "GET /test HTTP/1.0\r\n\ + connection: close\r\n\r\n", + ); + + let req = parse_ready!(&mut buf); + + assert_eq!(req.head().ctype, Some(ConnectionType::Close)); + } + + #[test] + fn test_conn_keep_alive_1_0() { + let mut buf = BytesMut::from( + "GET /test HTTP/1.0\r\n\ + connection: keep-alive\r\n\r\n", + ); + let req = parse_ready!(&mut buf); + + assert_eq!(req.head().ctype, Some(ConnectionType::KeepAlive)); + + let mut buf = BytesMut::from( + "GET /test HTTP/1.0\r\n\ + connection: Keep-Alive\r\n\r\n", + ); + let req = parse_ready!(&mut buf); + + assert_eq!(req.head().ctype, Some(ConnectionType::KeepAlive)); + } + + #[test] + fn test_conn_keep_alive_1_1() { + let mut buf = BytesMut::from( + "GET /test HTTP/1.1\r\n\ + connection: keep-alive\r\n\r\n", + ); + let req = parse_ready!(&mut buf); + + assert_eq!(req.head().ctype, Some(ConnectionType::KeepAlive)); + } + + #[test] + fn test_conn_other_1_0() { + let mut buf = BytesMut::from( + "GET /test HTTP/1.0\r\n\ + connection: other\r\n\r\n", + ); + let req = parse_ready!(&mut buf); + + assert_eq!(req.head().connection_type(), ConnectionType::Close); + } + + #[test] + fn test_conn_other_1_1() { + let mut buf = BytesMut::from( + "GET /test HTTP/1.1\r\n\ + connection: other\r\n\r\n", + ); + let req = parse_ready!(&mut buf); + + assert_eq!(req.head().ctype, None); + assert_eq!(req.head().connection_type(), ConnectionType::KeepAlive); + } + + #[test] + fn test_conn_upgrade() { + let mut buf = BytesMut::from( + "GET /test HTTP/1.1\r\n\ + upgrade: websockets\r\n\ + connection: upgrade\r\n\r\n", + ); + let req = parse_ready!(&mut buf); + + assert!(req.upgrade()); + assert_eq!(req.head().ctype, Some(ConnectionType::Upgrade)); + + let mut buf = BytesMut::from( + "GET /test HTTP/1.1\r\n\ + upgrade: Websockets\r\n\ + connection: Upgrade\r\n\r\n", + ); + let req = parse_ready!(&mut buf); + + assert!(req.upgrade()); + assert_eq!(req.head().ctype, Some(ConnectionType::Upgrade)); + } + + #[test] + fn test_conn_upgrade_connect_method() { + let mut buf = BytesMut::from( + "CONNECT /test HTTP/1.1\r\n\ + content-type: text/plain\r\n\r\n", + ); + let req = parse_ready!(&mut buf); + + assert!(req.upgrade()); + } + + #[test] + fn test_request_chunked() { + let mut buf = BytesMut::from( + "GET /test HTTP/1.1\r\n\ + transfer-encoding: chunked\r\n\r\n", + ); + let req = parse_ready!(&mut buf); + + if let Ok(val) = req.chunked() { + assert!(val); + } else { + unreachable!("Error"); + } + + // type in chunked + let mut buf = BytesMut::from( + "GET /test HTTP/1.1\r\n\ + transfer-encoding: chnked\r\n\r\n", + ); + let req = parse_ready!(&mut buf); + + if let Ok(val) = req.chunked() { + assert!(!val); + } else { + unreachable!("Error"); + } + } + + #[test] + fn test_headers_content_length_err_1() { + let mut buf = BytesMut::from( + "GET /test HTTP/1.1\r\n\ + content-length: line\r\n\r\n", + ); + + expect_parse_err!(&mut buf) + } + + #[test] + fn test_headers_content_length_err_2() { + let mut buf = BytesMut::from( + "GET /test HTTP/1.1\r\n\ + content-length: -1\r\n\r\n", + ); + + expect_parse_err!(&mut buf); + } + + #[test] + fn test_invalid_header() { + let mut buf = BytesMut::from( + "GET /test HTTP/1.1\r\n\ + test line\r\n\r\n", + ); + + expect_parse_err!(&mut buf); + } + + #[test] + fn test_invalid_name() { + let mut buf = BytesMut::from( + "GET /test HTTP/1.1\r\n\ + test[]: line\r\n\r\n", + ); + + expect_parse_err!(&mut buf); + } + + #[test] + fn test_http_request_bad_status_line() { + let mut buf = BytesMut::from("getpath \r\n\r\n"); + expect_parse_err!(&mut buf); + } + + #[test] + fn test_http_request_upgrade() { + let mut buf = BytesMut::from( + "GET /test HTTP/1.1\r\n\ + connection: upgrade\r\n\ + upgrade: websocket\r\n\r\n\ + some raw data", + ); + let mut reader = MessageDecoder::::default(); + let (req, pl) = reader.decode(&mut buf).unwrap().unwrap(); + assert_eq!(req.head().ctype, Some(ConnectionType::Upgrade)); + assert!(req.upgrade()); + assert!(pl.is_unhandled()); + } + + #[test] + fn test_http_request_parser_utf8() { + let mut buf = BytesMut::from( + "GET /test HTTP/1.1\r\n\ + x-test: тест\r\n\r\n", + ); + let req = parse_ready!(&mut buf); + + assert_eq!( + req.headers().get("x-test").unwrap().as_bytes(), + "тест".as_bytes() + ); + } + + #[test] + fn test_http_request_parser_two_slashes() { + let mut buf = BytesMut::from("GET //path HTTP/1.1\r\n\r\n"); + let req = parse_ready!(&mut buf); + + assert_eq!(req.path(), "//path"); + } + + #[test] + fn test_http_request_parser_bad_method() { + let mut buf = BytesMut::from("!12%()+=~$ /get HTTP/1.1\r\n\r\n"); + + expect_parse_err!(&mut buf); + } + + #[test] + fn test_http_request_parser_bad_version() { + let mut buf = BytesMut::from("GET //get HT/11\r\n\r\n"); + + expect_parse_err!(&mut buf); + } + + #[test] + fn test_http_request_chunked_payload() { + let mut buf = BytesMut::from( + "GET /test HTTP/1.1\r\n\ + transfer-encoding: chunked\r\n\r\n", + ); + let mut reader = MessageDecoder::::default(); + let (req, pl) = reader.decode(&mut buf).unwrap().unwrap(); + let mut pl = pl.unwrap(); + assert!(req.chunked().unwrap()); + + buf.extend(b"4\r\ndata\r\n4\r\nline\r\n0\r\n\r\n"); + assert_eq!( + pl.decode(&mut buf).unwrap().unwrap().chunk().as_ref(), + b"data" + ); + assert_eq!( + pl.decode(&mut buf).unwrap().unwrap().chunk().as_ref(), + b"line" + ); + assert!(pl.decode(&mut buf).unwrap().unwrap().eof()); + } + + #[test] + fn test_http_request_chunked_payload_and_next_message() { + let mut buf = BytesMut::from( + "GET /test HTTP/1.1\r\n\ + transfer-encoding: chunked\r\n\r\n", + ); + let mut reader = MessageDecoder::::default(); + let (req, pl) = reader.decode(&mut buf).unwrap().unwrap(); + let mut pl = pl.unwrap(); + assert!(req.chunked().unwrap()); + + buf.extend( + b"4\r\ndata\r\n4\r\nline\r\n0\r\n\r\n\ + POST /test2 HTTP/1.1\r\n\ + transfer-encoding: chunked\r\n\r\n" + .iter(), + ); + let msg = pl.decode(&mut buf).unwrap().unwrap(); + assert_eq!(msg.chunk().as_ref(), b"data"); + let msg = pl.decode(&mut buf).unwrap().unwrap(); + assert_eq!(msg.chunk().as_ref(), b"line"); + let msg = pl.decode(&mut buf).unwrap().unwrap(); + assert!(msg.eof()); + + let (req, _) = reader.decode(&mut buf).unwrap().unwrap(); + assert!(req.chunked().unwrap()); + assert_eq!(*req.method(), Method::POST); + assert!(req.chunked().unwrap()); + } + + #[test] + fn test_http_request_chunked_payload_chunks() { + let mut buf = BytesMut::from( + "GET /test HTTP/1.1\r\n\ + transfer-encoding: chunked\r\n\r\n", + ); + + let mut reader = MessageDecoder::::default(); + let (req, pl) = reader.decode(&mut buf).unwrap().unwrap(); + let mut pl = pl.unwrap(); + assert!(req.chunked().unwrap()); + + buf.extend(b"4\r\n1111\r\n"); + let msg = pl.decode(&mut buf).unwrap().unwrap(); + assert_eq!(msg.chunk().as_ref(), b"1111"); + + buf.extend(b"4\r\ndata\r"); + let msg = pl.decode(&mut buf).unwrap().unwrap(); + assert_eq!(msg.chunk().as_ref(), b"data"); + + buf.extend(b"\n4"); + assert!(pl.decode(&mut buf).unwrap().is_none()); + + buf.extend(b"\r"); + assert!(pl.decode(&mut buf).unwrap().is_none()); + buf.extend(b"\n"); + assert!(pl.decode(&mut buf).unwrap().is_none()); + + buf.extend(b"li"); + let msg = pl.decode(&mut buf).unwrap().unwrap(); + assert_eq!(msg.chunk().as_ref(), b"li"); + + //trailers + //buf.feed_data("test: test\r\n"); + //not_ready!(reader.parse(&mut buf, &mut readbuf)); + + buf.extend(b"ne\r\n0\r\n"); + let msg = pl.decode(&mut buf).unwrap().unwrap(); + assert_eq!(msg.chunk().as_ref(), b"ne"); + assert!(pl.decode(&mut buf).unwrap().is_none()); + + buf.extend(b"\r\n"); + assert!(pl.decode(&mut buf).unwrap().unwrap().eof()); + } + + #[test] + fn test_parse_chunked_payload_chunk_extension() { + let mut buf = BytesMut::from( + &"GET /test HTTP/1.1\r\n\ + transfer-encoding: chunked\r\n\r\n"[..], + ); + + let mut reader = MessageDecoder::::default(); + let (msg, pl) = reader.decode(&mut buf).unwrap().unwrap(); + let mut pl = pl.unwrap(); + assert!(msg.chunked().unwrap()); + + buf.extend(b"4;test\r\ndata\r\n4\r\nline\r\n0\r\n\r\n"); // test: test\r\n\r\n") + let chunk = pl.decode(&mut buf).unwrap().unwrap().chunk(); + assert_eq!(chunk, Bytes::from_static(b"data")); + let chunk = pl.decode(&mut buf).unwrap().unwrap().chunk(); + assert_eq!(chunk, Bytes::from_static(b"line")); + let msg = pl.decode(&mut buf).unwrap().unwrap(); + assert!(msg.eof()); + } +} diff --git a/actix-http/src/h1/dispatcher.rs b/actix-http/src/h1/dispatcher.rs new file mode 100644 index 000000000..09a17fcf7 --- /dev/null +++ b/actix-http/src/h1/dispatcher.rs @@ -0,0 +1,636 @@ +use std::collections::VecDeque; +use std::fmt::Debug; +use std::mem; +use std::time::Instant; + +use actix_codec::{AsyncRead, AsyncWrite, Framed}; +use actix_service::Service; +use actix_utils::cloneable::CloneableService; +use bitflags::bitflags; +use futures::{Async, Future, Poll, Sink, Stream}; +use log::{debug, error, trace}; +use tokio_timer::Delay; + +use crate::body::{Body, BodyLength, MessageBody, ResponseBody}; +use crate::config::ServiceConfig; +use crate::error::DispatchError; +use crate::error::{ParseError, PayloadError}; +use crate::request::Request; +use crate::response::Response; + +use super::codec::Codec; +use super::payload::{Payload, PayloadSender, PayloadStatus, PayloadWriter}; +use super::{Message, MessageType}; + +const MAX_PIPELINED_MESSAGES: usize = 16; + +bitflags! { + pub struct Flags: u8 { + const STARTED = 0b0000_0001; + const KEEPALIVE_ENABLED = 0b0000_0010; + const KEEPALIVE = 0b0000_0100; + const POLLED = 0b0000_1000; + const SHUTDOWN = 0b0010_0000; + const DISCONNECTED = 0b0100_0000; + const DROPPING = 0b1000_0000; + } +} + +/// Dispatcher for HTTP/1.1 protocol +pub struct Dispatcher + 'static, B: MessageBody> +where + S::Error: Debug, +{ + inner: Option>, +} + +struct InnerDispatcher + 'static, B: MessageBody> +where + S::Error: Debug, +{ + service: CloneableService, + flags: Flags, + framed: Framed, + error: Option, + config: ServiceConfig, + + state: State, + payload: Option, + messages: VecDeque, + + ka_expire: Instant, + ka_timer: Option, +} + +enum DispatcherMessage { + Item(Request), + Error(Response<()>), +} + +enum State, B: MessageBody> { + None, + ServiceCall(S::Future), + SendPayload(ResponseBody), +} + +impl, B: MessageBody> State { + fn is_empty(&self) -> bool { + if let State::None = self { + true + } else { + false + } + } +} + +impl Dispatcher +where + T: AsyncRead + AsyncWrite, + S: Service + 'static, + S::Error: Debug, + S::Response: Into>, + B: MessageBody, +{ + /// Create http/1 dispatcher. + pub fn new(stream: T, config: ServiceConfig, service: CloneableService) -> Self { + Dispatcher::with_timeout( + Framed::new(stream, Codec::new(config.clone())), + config, + None, + service, + ) + } + + /// Create http/1 dispatcher with slow request timeout. + pub fn with_timeout( + framed: Framed, + config: ServiceConfig, + timeout: Option, + service: CloneableService, + ) -> Self { + let keepalive = config.keep_alive_enabled(); + let flags = if keepalive { + Flags::KEEPALIVE | Flags::KEEPALIVE_ENABLED + } else { + Flags::empty() + }; + + // keep-alive timer + let (ka_expire, ka_timer) = if let Some(delay) = timeout { + (delay.deadline(), Some(delay)) + } else if let Some(delay) = config.keep_alive_timer() { + (delay.deadline(), Some(delay)) + } else { + (config.now(), None) + }; + + Dispatcher { + inner: Some(InnerDispatcher { + framed, + payload: None, + state: State::None, + error: None, + messages: VecDeque::new(), + service, + flags, + config, + ka_expire, + ka_timer, + }), + } + } +} + +impl InnerDispatcher +where + T: AsyncRead + AsyncWrite, + S: Service + 'static, + S::Error: Debug, + S::Response: Into>, + B: MessageBody, +{ + fn can_read(&self) -> bool { + if self.flags.contains(Flags::DISCONNECTED) { + return false; + } + + if let Some(ref info) = self.payload { + info.need_read() == PayloadStatus::Read + } else { + true + } + } + + // if checked is set to true, delay disconnect until all tasks have finished. + fn client_disconnected(&mut self) { + self.flags.insert(Flags::DISCONNECTED); + if let Some(mut payload) = self.payload.take() { + payload.set_error(PayloadError::Incomplete(None)); + } + } + + /// Flush stream + fn poll_flush(&mut self) -> Poll { + if !self.framed.is_write_buf_empty() { + match self.framed.poll_complete() { + Ok(Async::NotReady) => Ok(Async::NotReady), + Err(err) => { + debug!("Error sending data: {}", err); + Err(err.into()) + } + Ok(Async::Ready(_)) => { + // if payload is not consumed we can not use connection + if self.payload.is_some() && self.state.is_empty() { + return Err(DispatchError::PayloadIsNotConsumed); + } + Ok(Async::Ready(true)) + } + } + } else { + Ok(Async::Ready(false)) + } + } + + fn send_response( + &mut self, + message: Response<()>, + body: ResponseBody, + ) -> Result, DispatchError> { + self.framed + .force_send(Message::Item((message, body.length()))) + .map_err(|err| { + if let Some(mut payload) = self.payload.take() { + payload.set_error(PayloadError::Incomplete(None)); + } + DispatchError::Io(err) + })?; + + self.flags + .set(Flags::KEEPALIVE, self.framed.get_codec().keepalive()); + match body.length() { + BodyLength::None | BodyLength::Empty => Ok(State::None), + _ => Ok(State::SendPayload(body)), + } + } + + fn poll_response(&mut self) -> Result<(), DispatchError> { + let mut retry = self.can_read(); + loop { + let state = match mem::replace(&mut self.state, State::None) { + State::None => match self.messages.pop_front() { + Some(DispatcherMessage::Item(req)) => { + Some(self.handle_request(req)?) + } + Some(DispatcherMessage::Error(res)) => { + self.send_response(res, ResponseBody::Other(Body::Empty))?; + None + } + None => None, + }, + State::ServiceCall(mut fut) => match fut.poll() { + Ok(Async::Ready(res)) => { + let (res, body) = res.into().replace_body(()); + Some(self.send_response(res, body)?) + } + Ok(Async::NotReady) => { + self.state = State::ServiceCall(fut); + None + } + Err(_e) => { + let res: Response = Response::InternalServerError().finish(); + let (res, body) = res.replace_body(()); + Some(self.send_response(res, body.into_body())?) + } + }, + State::SendPayload(mut stream) => { + loop { + if !self.framed.is_write_buf_full() { + match stream + .poll_next() + .map_err(|_| DispatchError::Unknown)? + { + Async::Ready(Some(item)) => { + self.framed + .force_send(Message::Chunk(Some(item)))?; + continue; + } + Async::Ready(None) => { + self.framed.force_send(Message::Chunk(None))?; + } + Async::NotReady => { + self.state = State::SendPayload(stream); + return Ok(()); + } + } + } else { + self.state = State::SendPayload(stream); + return Ok(()); + } + break; + } + None + } + }; + + match state { + Some(state) => self.state = state, + None => { + // if read-backpressure is enabled and we consumed some data. + // we may read more data and retry + if !retry && self.can_read() && self.poll_request()? { + retry = self.can_read(); + continue; + } + break; + } + } + } + + Ok(()) + } + + fn handle_request(&mut self, req: Request) -> Result, DispatchError> { + let mut task = self.service.call(req); + match task.poll() { + Ok(Async::Ready(res)) => { + let (res, body) = res.into().replace_body(()); + self.send_response(res, body) + } + Ok(Async::NotReady) => Ok(State::ServiceCall(task)), + Err(_e) => { + let res: Response = Response::InternalServerError().finish(); + let (res, body) = res.replace_body(()); + self.send_response(res, body.into_body()) + } + } + } + + /// Process one incoming requests + pub(self) fn poll_request(&mut self) -> Result { + // limit a mount of non processed requests + if self.messages.len() >= MAX_PIPELINED_MESSAGES { + return Ok(false); + } + + let mut updated = false; + loop { + match self.framed.poll() { + Ok(Async::Ready(Some(msg))) => { + updated = true; + self.flags.insert(Flags::STARTED); + + match msg { + Message::Item(mut req) => { + match self.framed.get_codec().message_type() { + MessageType::Payload | MessageType::Stream => { + let (ps, pl) = Payload::create(false); + let (req1, _) = + req.replace_payload(crate::Payload::H1(pl)); + req = req1; + self.payload = Some(ps); + } + //MessageType::Stream => { + // self.unhandled = Some(req); + // return Ok(updated); + //} + _ => (), + } + + // handle request early + if self.state.is_empty() { + self.state = self.handle_request(req)?; + } else { + self.messages.push_back(DispatcherMessage::Item(req)); + } + } + Message::Chunk(Some(chunk)) => { + if let Some(ref mut payload) = self.payload { + payload.feed_data(chunk); + } else { + error!( + "Internal server error: unexpected payload chunk" + ); + self.flags.insert(Flags::DISCONNECTED); + self.messages.push_back(DispatcherMessage::Error( + Response::InternalServerError().finish().drop_body(), + )); + self.error = Some(DispatchError::InternalError); + break; + } + } + Message::Chunk(None) => { + if let Some(mut payload) = self.payload.take() { + payload.feed_eof(); + } else { + error!("Internal server error: unexpected eof"); + self.flags.insert(Flags::DISCONNECTED); + self.messages.push_back(DispatcherMessage::Error( + Response::InternalServerError().finish().drop_body(), + )); + self.error = Some(DispatchError::InternalError); + break; + } + } + } + } + Ok(Async::Ready(None)) => { + self.client_disconnected(); + break; + } + Ok(Async::NotReady) => break, + Err(ParseError::Io(e)) => { + self.client_disconnected(); + self.error = Some(DispatchError::Io(e)); + break; + } + Err(e) => { + if let Some(mut payload) = self.payload.take() { + payload.set_error(PayloadError::EncodingCorrupted); + } + + // Malformed requests should be responded with 400 + self.messages.push_back(DispatcherMessage::Error( + Response::BadRequest().finish().drop_body(), + )); + self.flags.insert(Flags::DISCONNECTED); + self.error = Some(e.into()); + break; + } + } + } + + if self.ka_timer.is_some() && updated { + if let Some(expire) = self.config.keep_alive_expire() { + self.ka_expire = expire; + } + } + Ok(updated) + } + + /// keep-alive timer + fn poll_keepalive(&mut self) -> Result<(), DispatchError> { + if self.ka_timer.is_none() { + // shutdown timeout + if self.flags.contains(Flags::SHUTDOWN) { + if let Some(interval) = self.config.client_disconnect_timer() { + self.ka_timer = Some(Delay::new(interval)); + } else { + self.flags.insert(Flags::DISCONNECTED); + return Ok(()); + } + } else { + return Ok(()); + } + } + + match self.ka_timer.as_mut().unwrap().poll().map_err(|e| { + error!("Timer error {:?}", e); + DispatchError::Unknown + })? { + Async::Ready(_) => { + // if we get timeout during shutdown, drop connection + if self.flags.contains(Flags::SHUTDOWN) { + return Err(DispatchError::DisconnectTimeout); + } else if self.ka_timer.as_mut().unwrap().deadline() >= self.ka_expire { + // check for any outstanding tasks + if self.state.is_empty() && self.framed.is_write_buf_empty() { + if self.flags.contains(Flags::STARTED) { + trace!("Keep-alive timeout, close connection"); + self.flags.insert(Flags::SHUTDOWN); + + // start shutdown timer + if let Some(deadline) = self.config.client_disconnect_timer() + { + if let Some(timer) = self.ka_timer.as_mut() { + timer.reset(deadline); + let _ = timer.poll(); + } + } else { + // no shutdown timeout, drop socket + self.flags.insert(Flags::DISCONNECTED); + return Ok(()); + } + } else { + // timeout on first request (slow request) return 408 + if !self.flags.contains(Flags::STARTED) { + trace!("Slow request timeout"); + let _ = self.send_response( + Response::RequestTimeout().finish().drop_body(), + ResponseBody::Other(Body::Empty), + ); + } else { + trace!("Keep-alive connection timeout"); + } + self.flags.insert(Flags::STARTED | Flags::SHUTDOWN); + self.state = State::None; + } + } else if let Some(deadline) = self.config.keep_alive_expire() { + if let Some(timer) = self.ka_timer.as_mut() { + timer.reset(deadline); + let _ = timer.poll(); + } + } + } else if let Some(timer) = self.ka_timer.as_mut() { + timer.reset(self.ka_expire); + let _ = timer.poll(); + } + } + Async::NotReady => (), + } + + Ok(()) + } +} + +impl Future for Dispatcher +where + T: AsyncRead + AsyncWrite, + S: Service, + S::Error: Debug, + S::Response: Into>, + B: MessageBody, +{ + type Item = (); + type Error = DispatchError; + + #[inline] + fn poll(&mut self) -> Poll { + let inner = self.inner.as_mut().unwrap(); + + if inner.flags.contains(Flags::SHUTDOWN) { + inner.poll_keepalive()?; + if inner.flags.contains(Flags::DISCONNECTED) { + Ok(Async::Ready(())) + } else { + // try_ready!(inner.poll_flush()); + match inner.framed.get_mut().shutdown()? { + Async::Ready(_) => Ok(Async::Ready(())), + Async::NotReady => Ok(Async::NotReady), + } + } + } else { + inner.poll_keepalive()?; + inner.poll_request()?; + loop { + inner.poll_response()?; + if let Async::Ready(false) = inner.poll_flush()? { + break; + } + } + + if inner.flags.contains(Flags::DISCONNECTED) { + return Ok(Async::Ready(())); + } + + // keep-alive and stream errors + if inner.state.is_empty() && inner.framed.is_write_buf_empty() { + if let Some(err) = inner.error.take() { + return Err(err); + } + // disconnect if keep-alive is not enabled + else if inner.flags.contains(Flags::STARTED) + && !inner.flags.intersects(Flags::KEEPALIVE) + { + inner.flags.insert(Flags::SHUTDOWN); + self.poll() + } + // disconnect if shutdown + else if inner.flags.contains(Flags::SHUTDOWN) { + self.poll() + } else { + return Ok(Async::NotReady); + } + } else { + return Ok(Async::NotReady); + } + } + } +} + +#[cfg(test)] +mod tests { + use std::{cmp, io}; + + use actix_codec::{AsyncRead, AsyncWrite}; + use actix_service::IntoService; + use bytes::{Buf, Bytes}; + use futures::future::{lazy, ok}; + + use super::*; + use crate::error::Error; + + struct Buffer { + buf: Bytes, + err: Option, + } + + impl Buffer { + fn new(data: &'static str) -> Buffer { + Buffer { + buf: Bytes::from(data), + err: None, + } + } + } + + impl AsyncRead for Buffer {} + impl io::Read for Buffer { + fn read(&mut self, dst: &mut [u8]) -> Result { + if self.buf.is_empty() { + if self.err.is_some() { + Err(self.err.take().unwrap()) + } else { + Err(io::Error::new(io::ErrorKind::WouldBlock, "")) + } + } else { + let size = cmp::min(self.buf.len(), dst.len()); + let b = self.buf.split_to(size); + dst[..size].copy_from_slice(&b); + Ok(size) + } + } + } + + impl io::Write for Buffer { + fn write(&mut self, buf: &[u8]) -> io::Result { + Ok(buf.len()) + } + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } + } + impl AsyncWrite for Buffer { + fn shutdown(&mut self) -> Poll<(), io::Error> { + Ok(Async::Ready(())) + } + fn write_buf(&mut self, _: &mut B) -> Poll { + Ok(Async::NotReady) + } + } + + #[test] + fn test_req_parse_err() { + let mut sys = actix_rt::System::new("test"); + let _ = sys.block_on(lazy(|| { + let buf = Buffer::new("GET /test HTTP/1\r\n\r\n"); + + let mut h1 = Dispatcher::new( + buf, + ServiceConfig::default(), + CloneableService::new( + (|_| ok::<_, Error>(Response::Ok().finish())).into_service(), + ), + ); + assert!(h1.poll().is_ok()); + assert!(h1.poll().is_ok()); + assert!(h1 + .inner + .as_ref() + .unwrap() + .flags + .contains(Flags::DISCONNECTED)); + // assert_eq!(h1.tasks.len(), 1); + ok::<_, ()>(()) + })); + } +} diff --git a/actix-http/src/h1/encoder.rs b/actix-http/src/h1/encoder.rs new file mode 100644 index 000000000..712d123eb --- /dev/null +++ b/actix-http/src/h1/encoder.rs @@ -0,0 +1,421 @@ +#![allow(unused_imports, unused_variables, dead_code)] +use std::fmt::Write as FmtWrite; +use std::io::Write; +use std::marker::PhantomData; +use std::str::FromStr; +use std::{cmp, fmt, io, mem}; + +use bytes::{BufMut, Bytes, BytesMut}; +use http::header::{ + HeaderValue, ACCEPT_ENCODING, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING, +}; +use http::{HeaderMap, Method, StatusCode, Version}; + +use crate::body::BodyLength; +use crate::config::ServiceConfig; +use crate::header::ContentEncoding; +use crate::helpers; +use crate::message::{ConnectionType, RequestHead, ResponseHead}; +use crate::request::Request; +use crate::response::Response; + +const AVERAGE_HEADER_SIZE: usize = 30; + +#[derive(Debug)] +pub(crate) struct MessageEncoder { + pub length: BodyLength, + pub te: TransferEncoding, + _t: PhantomData, +} + +impl Default for MessageEncoder { + fn default() -> Self { + MessageEncoder { + length: BodyLength::None, + te: TransferEncoding::empty(), + _t: PhantomData, + } + } +} + +pub(crate) trait MessageType: Sized { + fn status(&self) -> Option; + + fn connection_type(&self) -> Option; + + fn headers(&self) -> &HeaderMap; + + fn chunked(&self) -> bool; + + fn encode_status(&mut self, dst: &mut BytesMut) -> io::Result<()>; + + fn encode_headers( + &mut self, + dst: &mut BytesMut, + version: Version, + mut length: BodyLength, + ctype: ConnectionType, + config: &ServiceConfig, + ) -> io::Result<()> { + let chunked = self.chunked(); + let mut skip_len = length != BodyLength::Stream; + + // Content length + if let Some(status) = self.status() { + match status { + StatusCode::NO_CONTENT + | StatusCode::CONTINUE + | StatusCode::PROCESSING => length = BodyLength::None, + StatusCode::SWITCHING_PROTOCOLS => { + skip_len = true; + length = BodyLength::Stream; + } + _ => (), + } + } + match length { + BodyLength::Stream => { + if chunked { + dst.extend_from_slice(b"\r\ntransfer-encoding: chunked\r\n") + } else { + skip_len = false; + dst.extend_from_slice(b"\r\n"); + } + } + BodyLength::Empty => { + dst.extend_from_slice(b"\r\ncontent-length: 0\r\n"); + } + BodyLength::Sized(len) => helpers::write_content_length(len, dst), + BodyLength::Sized64(len) => { + dst.extend_from_slice(b"\r\ncontent-length: "); + write!(dst.writer(), "{}", len)?; + dst.extend_from_slice(b"\r\n"); + } + BodyLength::None => dst.extend_from_slice(b"\r\n"), + } + + // Connection + match ctype { + ConnectionType::Upgrade => dst.extend_from_slice(b"connection: upgrade\r\n"), + ConnectionType::KeepAlive if version < Version::HTTP_11 => { + dst.extend_from_slice(b"connection: keep-alive\r\n") + } + ConnectionType::Close if version >= Version::HTTP_11 => { + dst.extend_from_slice(b"connection: close\r\n") + } + _ => (), + } + + // write headers + let mut pos = 0; + let mut has_date = false; + let mut remaining = dst.remaining_mut(); + let mut buf = unsafe { &mut *(dst.bytes_mut() as *mut [u8]) }; + for (key, value) in self.headers() { + match *key { + CONNECTION => continue, + TRANSFER_ENCODING | CONTENT_LENGTH if skip_len => continue, + DATE => { + has_date = true; + } + _ => (), + } + + let v = value.as_ref(); + let k = key.as_str().as_bytes(); + let len = k.len() + v.len() + 4; + if len > remaining { + unsafe { + dst.advance_mut(pos); + } + pos = 0; + dst.reserve(len); + remaining = dst.remaining_mut(); + unsafe { + buf = &mut *(dst.bytes_mut() as *mut _); + } + } + + buf[pos..pos + k.len()].copy_from_slice(k); + pos += k.len(); + buf[pos..pos + 2].copy_from_slice(b": "); + pos += 2; + buf[pos..pos + v.len()].copy_from_slice(v); + pos += v.len(); + buf[pos..pos + 2].copy_from_slice(b"\r\n"); + pos += 2; + remaining -= len; + } + unsafe { + dst.advance_mut(pos); + } + + // optimized date header, set_date writes \r\n + if !has_date { + config.set_date(dst); + } else { + // msg eof + dst.extend_from_slice(b"\r\n"); + } + + Ok(()) + } +} + +impl MessageType for Response<()> { + fn status(&self) -> Option { + Some(self.head().status) + } + + fn chunked(&self) -> bool { + !self.head().no_chunking + } + + fn connection_type(&self) -> Option { + self.head().ctype + } + + fn headers(&self) -> &HeaderMap { + &self.head().headers + } + + fn encode_status(&mut self, dst: &mut BytesMut) -> io::Result<()> { + let head = self.head(); + let reason = head.reason().as_bytes(); + dst.reserve(256 + head.headers.len() * AVERAGE_HEADER_SIZE + reason.len()); + + // status line + helpers::write_status_line(head.version, head.status.as_u16(), dst); + dst.extend_from_slice(reason); + Ok(()) + } +} + +impl MessageType for RequestHead { + fn status(&self) -> Option { + None + } + + fn connection_type(&self) -> Option { + self.ctype + } + + fn chunked(&self) -> bool { + !self.no_chunking + } + + fn headers(&self) -> &HeaderMap { + &self.headers + } + + fn encode_status(&mut self, dst: &mut BytesMut) -> io::Result<()> { + write!( + Writer(dst), + "{} {} {}", + self.method, + self.uri.path_and_query().map(|u| u.as_str()).unwrap_or("/"), + match self.version { + Version::HTTP_09 => "HTTP/0.9", + Version::HTTP_10 => "HTTP/1.0", + Version::HTTP_11 => "HTTP/1.1", + Version::HTTP_2 => "HTTP/2.0", + } + ) + .map_err(|e| io::Error::new(io::ErrorKind::Other, e)) + } +} + +impl MessageEncoder { + /// Encode message + pub fn encode_chunk(&mut self, msg: &[u8], buf: &mut BytesMut) -> io::Result { + self.te.encode(msg, buf) + } + + /// Encode eof + pub fn encode_eof(&mut self, buf: &mut BytesMut) -> io::Result<()> { + self.te.encode_eof(buf) + } + + pub fn encode( + &mut self, + dst: &mut BytesMut, + message: &mut T, + head: bool, + stream: bool, + version: Version, + length: BodyLength, + ctype: ConnectionType, + config: &ServiceConfig, + ) -> io::Result<()> { + // transfer encoding + if !head { + self.te = match length { + BodyLength::Empty => TransferEncoding::empty(), + BodyLength::Sized(len) => TransferEncoding::length(len as u64), + BodyLength::Sized64(len) => TransferEncoding::length(len), + BodyLength::Stream => { + if message.chunked() && !stream { + TransferEncoding::chunked() + } else { + TransferEncoding::eof() + } + } + BodyLength::None => TransferEncoding::empty(), + }; + } else { + self.te = TransferEncoding::empty(); + } + + message.encode_status(dst)?; + message.encode_headers(dst, version, length, ctype, config) + } +} + +/// Encoders to handle different Transfer-Encodings. +#[derive(Debug)] +pub(crate) struct TransferEncoding { + kind: TransferEncodingKind, +} + +#[derive(Debug, PartialEq, Clone)] +enum TransferEncodingKind { + /// An Encoder for when Transfer-Encoding includes `chunked`. + Chunked(bool), + /// An Encoder for when Content-Length is set. + /// + /// Enforces that the body is not longer than the Content-Length header. + Length(u64), + /// An Encoder for when Content-Length is not known. + /// + /// Application decides when to stop writing. + Eof, +} + +impl TransferEncoding { + #[inline] + pub fn empty() -> TransferEncoding { + TransferEncoding { + kind: TransferEncodingKind::Length(0), + } + } + + #[inline] + pub fn eof() -> TransferEncoding { + TransferEncoding { + kind: TransferEncodingKind::Eof, + } + } + + #[inline] + pub fn chunked() -> TransferEncoding { + TransferEncoding { + kind: TransferEncodingKind::Chunked(false), + } + } + + #[inline] + pub fn length(len: u64) -> TransferEncoding { + TransferEncoding { + kind: TransferEncodingKind::Length(len), + } + } + + /// Encode message. Return `EOF` state of encoder + #[inline] + pub fn encode(&mut self, msg: &[u8], buf: &mut BytesMut) -> io::Result { + match self.kind { + TransferEncodingKind::Eof => { + let eof = msg.is_empty(); + buf.extend_from_slice(msg); + Ok(eof) + } + TransferEncodingKind::Chunked(ref mut eof) => { + if *eof { + return Ok(true); + } + + if msg.is_empty() { + *eof = true; + buf.extend_from_slice(b"0\r\n\r\n"); + } else { + writeln!(Writer(buf), "{:X}\r", msg.len()) + .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; + + buf.reserve(msg.len() + 2); + buf.extend_from_slice(msg); + buf.extend_from_slice(b"\r\n"); + } + Ok(*eof) + } + TransferEncodingKind::Length(ref mut remaining) => { + if *remaining > 0 { + if msg.is_empty() { + return Ok(*remaining == 0); + } + let len = cmp::min(*remaining, msg.len() as u64); + + buf.extend_from_slice(&msg[..len as usize]); + + *remaining -= len as u64; + Ok(*remaining == 0) + } else { + Ok(true) + } + } + } + } + + /// Encode eof. Return `EOF` state of encoder + #[inline] + pub fn encode_eof(&mut self, buf: &mut BytesMut) -> io::Result<()> { + match self.kind { + TransferEncodingKind::Eof => Ok(()), + TransferEncodingKind::Length(rem) => { + if rem != 0 { + Err(io::Error::new(io::ErrorKind::UnexpectedEof, "")) + } else { + Ok(()) + } + } + TransferEncodingKind::Chunked(ref mut eof) => { + if !*eof { + *eof = true; + buf.extend_from_slice(b"0\r\n\r\n"); + } + Ok(()) + } + } + } +} + +struct Writer<'a>(pub &'a mut BytesMut); + +impl<'a> io::Write for Writer<'a> { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.0.extend_from_slice(buf); + Ok(buf.len()) + } + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use bytes::Bytes; + + #[test] + fn test_chunked_te() { + let mut bytes = BytesMut::new(); + let mut enc = TransferEncoding::chunked(); + { + assert!(!enc.encode(b"test", &mut bytes).ok().unwrap()); + assert!(enc.encode(b"", &mut bytes).ok().unwrap()); + } + assert_eq!( + bytes.take().freeze(), + Bytes::from_static(b"4\r\ntest\r\n0\r\n\r\n") + ); + } +} diff --git a/actix-http/src/h1/mod.rs b/actix-http/src/h1/mod.rs new file mode 100644 index 000000000..e3d63c521 --- /dev/null +++ b/actix-http/src/h1/mod.rs @@ -0,0 +1,69 @@ +//! HTTP/1 implementation +use bytes::Bytes; + +mod client; +mod codec; +mod decoder; +mod dispatcher; +mod encoder; +mod payload; +mod service; + +pub use self::client::{ClientCodec, ClientPayloadCodec}; +pub use self::codec::Codec; +pub use self::dispatcher::Dispatcher; +pub use self::payload::{Payload, PayloadBuffer}; +pub use self::service::{H1Service, H1ServiceHandler, OneRequest}; + +#[derive(Debug)] +/// Codec message +pub enum Message { + /// Http message + Item(T), + /// Payload chunk + Chunk(Option), +} + +impl From for Message { + fn from(item: T) -> Self { + Message::Item(item) + } +} + +/// Incoming request type +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum MessageType { + None, + Payload, + Stream, +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::request::Request; + + impl Message { + pub fn message(self) -> Request { + match self { + Message::Item(req) => req, + _ => panic!("error"), + } + } + + pub fn chunk(self) -> Bytes { + match self { + Message::Chunk(Some(data)) => data, + _ => panic!("error"), + } + } + + pub fn eof(self) -> bool { + match self { + Message::Chunk(None) => true, + Message::Chunk(Some(_)) => false, + _ => panic!("error"), + } + } + } +} diff --git a/actix-http/src/h1/payload.rs b/actix-http/src/h1/payload.rs new file mode 100644 index 000000000..6665a0e41 --- /dev/null +++ b/actix-http/src/h1/payload.rs @@ -0,0 +1,716 @@ +//! Payload stream +use bytes::{Bytes, BytesMut}; +#[cfg(not(test))] +use futures::task::current as current_task; +use futures::task::Task; +use futures::{Async, Poll, Stream}; +use std::cell::RefCell; +use std::cmp; +use std::collections::VecDeque; +use std::rc::{Rc, Weak}; + +use crate::error::PayloadError; + +/// max buffer size 32k +pub(crate) const MAX_BUFFER_SIZE: usize = 32_768; + +#[derive(Debug, PartialEq)] +pub(crate) enum PayloadStatus { + Read, + Pause, + Dropped, +} + +/// Buffered stream of bytes chunks +/// +/// Payload stores chunks in a vector. First chunk can be received with +/// `.readany()` method. Payload stream is not thread safe. Payload does not +/// notify current task when new data is available. +/// +/// Payload stream can be used as `Response` body stream. +#[derive(Debug)] +pub struct Payload { + inner: Rc>, +} + +impl Payload { + /// Create payload stream. + /// + /// This method construct two objects responsible for bytes stream + /// generation. + /// + /// * `PayloadSender` - *Sender* side of the stream + /// + /// * `Payload` - *Receiver* side of the stream + pub fn create(eof: bool) -> (PayloadSender, Payload) { + let shared = Rc::new(RefCell::new(Inner::new(eof))); + + ( + PayloadSender { + inner: Rc::downgrade(&shared), + }, + Payload { inner: shared }, + ) + } + + /// Create empty payload + #[doc(hidden)] + pub fn empty() -> Payload { + Payload { + inner: Rc::new(RefCell::new(Inner::new(true))), + } + } + + /// Length of the data in this payload + #[cfg(test)] + pub fn len(&self) -> usize { + self.inner.borrow().len() + } + + /// Is payload empty + #[cfg(test)] + pub fn is_empty(&self) -> bool { + self.inner.borrow().len() == 0 + } + + /// Put unused data back to payload + #[inline] + pub fn unread_data(&mut self, data: Bytes) { + self.inner.borrow_mut().unread_data(data); + } + + #[cfg(test)] + pub(crate) fn readall(&self) -> Option { + self.inner.borrow_mut().readall() + } + + #[inline] + /// Set read buffer capacity + /// + /// Default buffer capacity is 32Kb. + pub fn set_read_buffer_capacity(&mut self, cap: usize) { + self.inner.borrow_mut().capacity = cap; + } +} + +impl Stream for Payload { + type Item = Bytes; + type Error = PayloadError; + + #[inline] + fn poll(&mut self) -> Poll, PayloadError> { + self.inner.borrow_mut().readany() + } +} + +impl Clone for Payload { + fn clone(&self) -> Payload { + Payload { + inner: Rc::clone(&self.inner), + } + } +} + +/// Payload writer interface. +pub(crate) trait PayloadWriter { + /// Set stream error. + fn set_error(&mut self, err: PayloadError); + + /// Write eof into a stream which closes reading side of a stream. + fn feed_eof(&mut self); + + /// Feed bytes into a payload stream + fn feed_data(&mut self, data: Bytes); + + /// Need read data + fn need_read(&self) -> PayloadStatus; +} + +/// Sender part of the payload stream +pub struct PayloadSender { + inner: Weak>, +} + +impl PayloadWriter for PayloadSender { + #[inline] + fn set_error(&mut self, err: PayloadError) { + if let Some(shared) = self.inner.upgrade() { + shared.borrow_mut().set_error(err) + } + } + + #[inline] + fn feed_eof(&mut self) { + if let Some(shared) = self.inner.upgrade() { + shared.borrow_mut().feed_eof() + } + } + + #[inline] + fn feed_data(&mut self, data: Bytes) { + if let Some(shared) = self.inner.upgrade() { + shared.borrow_mut().feed_data(data) + } + } + + #[inline] + fn need_read(&self) -> PayloadStatus { + // we check need_read only if Payload (other side) is alive, + // otherwise always return true (consume payload) + if let Some(shared) = self.inner.upgrade() { + if shared.borrow().need_read { + PayloadStatus::Read + } else { + #[cfg(not(test))] + { + if shared.borrow_mut().io_task.is_none() { + shared.borrow_mut().io_task = Some(current_task()); + } + } + PayloadStatus::Pause + } + } else { + PayloadStatus::Dropped + } + } +} + +#[derive(Debug)] +struct Inner { + len: usize, + eof: bool, + err: Option, + need_read: bool, + items: VecDeque, + capacity: usize, + task: Option, + io_task: Option, +} + +impl Inner { + fn new(eof: bool) -> Self { + Inner { + eof, + len: 0, + err: None, + items: VecDeque::new(), + need_read: true, + capacity: MAX_BUFFER_SIZE, + task: None, + io_task: None, + } + } + + #[inline] + fn set_error(&mut self, err: PayloadError) { + self.err = Some(err); + } + + #[inline] + fn feed_eof(&mut self) { + self.eof = true; + } + + #[inline] + fn feed_data(&mut self, data: Bytes) { + self.len += data.len(); + self.items.push_back(data); + self.need_read = self.len < self.capacity; + if let Some(task) = self.task.take() { + task.notify() + } + } + + #[cfg(test)] + fn len(&self) -> usize { + self.len + } + + #[cfg(test)] + pub(crate) fn readall(&mut self) -> Option { + let len = self.items.iter().map(|b| b.len()).sum(); + if len > 0 { + let mut buf = BytesMut::with_capacity(len); + for item in &self.items { + buf.extend_from_slice(item); + } + self.items = VecDeque::new(); + self.len = 0; + Some(buf.take().freeze()) + } else { + self.need_read = true; + None + } + } + + fn readany(&mut self) -> Poll, PayloadError> { + if let Some(data) = self.items.pop_front() { + self.len -= data.len(); + self.need_read = self.len < self.capacity; + #[cfg(not(test))] + { + if self.need_read && self.task.is_none() { + self.task = Some(current_task()); + } + if let Some(task) = self.io_task.take() { + task.notify() + } + } + Ok(Async::Ready(Some(data))) + } else if let Some(err) = self.err.take() { + Err(err) + } else if self.eof { + Ok(Async::Ready(None)) + } else { + self.need_read = true; + #[cfg(not(test))] + { + if self.task.is_none() { + self.task = Some(current_task()); + } + if let Some(task) = self.io_task.take() { + task.notify() + } + } + Ok(Async::NotReady) + } + } + + fn unread_data(&mut self, data: Bytes) { + self.len += data.len(); + self.items.push_front(data); + } +} + +/// Payload buffer +pub struct PayloadBuffer { + len: usize, + items: VecDeque, + stream: S, +} + +impl PayloadBuffer +where + S: Stream, +{ + /// Create new `PayloadBuffer` instance + pub fn new(stream: S) -> Self { + PayloadBuffer { + len: 0, + items: VecDeque::new(), + stream, + } + } + + /// Get mutable reference to an inner stream. + pub fn get_mut(&mut self) -> &mut S { + &mut self.stream + } + + #[inline] + fn poll_stream(&mut self) -> Poll { + self.stream.poll().map(|res| match res { + Async::Ready(Some(data)) => { + self.len += data.len(); + self.items.push_back(data); + Async::Ready(true) + } + Async::Ready(None) => Async::Ready(false), + Async::NotReady => Async::NotReady, + }) + } + + /// Read first available chunk of bytes + #[inline] + pub fn readany(&mut self) -> Poll, PayloadError> { + if let Some(data) = self.items.pop_front() { + self.len -= data.len(); + Ok(Async::Ready(Some(data))) + } else { + match self.poll_stream()? { + Async::Ready(true) => self.readany(), + Async::Ready(false) => Ok(Async::Ready(None)), + Async::NotReady => Ok(Async::NotReady), + } + } + } + + /// Check if buffer contains enough bytes + #[inline] + pub fn can_read(&mut self, size: usize) -> Poll, PayloadError> { + if size <= self.len { + Ok(Async::Ready(Some(true))) + } else { + match self.poll_stream()? { + Async::Ready(true) => self.can_read(size), + Async::Ready(false) => Ok(Async::Ready(None)), + Async::NotReady => Ok(Async::NotReady), + } + } + } + + /// Return reference to the first chunk of data + #[inline] + pub fn get_chunk(&mut self) -> Poll, PayloadError> { + if self.items.is_empty() { + match self.poll_stream()? { + Async::Ready(true) => (), + Async::Ready(false) => return Ok(Async::Ready(None)), + Async::NotReady => return Ok(Async::NotReady), + } + } + match self.items.front().map(|c| c.as_ref()) { + Some(chunk) => Ok(Async::Ready(Some(chunk))), + None => Ok(Async::NotReady), + } + } + + /// Read exact number of bytes + #[inline] + pub fn read_exact(&mut self, size: usize) -> Poll, PayloadError> { + if size <= self.len { + self.len -= size; + let mut chunk = self.items.pop_front().unwrap(); + if size < chunk.len() { + let buf = chunk.split_to(size); + self.items.push_front(chunk); + Ok(Async::Ready(Some(buf))) + } else if size == chunk.len() { + Ok(Async::Ready(Some(chunk))) + } else { + let mut buf = BytesMut::with_capacity(size); + buf.extend_from_slice(&chunk); + + while buf.len() < size { + let mut chunk = self.items.pop_front().unwrap(); + let rem = cmp::min(size - buf.len(), chunk.len()); + buf.extend_from_slice(&chunk.split_to(rem)); + if !chunk.is_empty() { + self.items.push_front(chunk); + } + } + Ok(Async::Ready(Some(buf.freeze()))) + } + } else { + match self.poll_stream()? { + Async::Ready(true) => self.read_exact(size), + Async::Ready(false) => Ok(Async::Ready(None)), + Async::NotReady => Ok(Async::NotReady), + } + } + } + + /// Remove specified amount if bytes from buffer + #[inline] + pub fn drop_bytes(&mut self, size: usize) { + if size <= self.len { + self.len -= size; + + let mut len = 0; + while len < size { + let mut chunk = self.items.pop_front().unwrap(); + let rem = cmp::min(size - len, chunk.len()); + len += rem; + if rem < chunk.len() { + chunk.split_to(rem); + self.items.push_front(chunk); + } + } + } + } + + /// Copy buffered data + pub fn copy(&mut self, size: usize) -> Poll, PayloadError> { + if size <= self.len { + let mut buf = BytesMut::with_capacity(size); + for chunk in &self.items { + if buf.len() < size { + let rem = cmp::min(size - buf.len(), chunk.len()); + buf.extend_from_slice(&chunk[..rem]); + } + if buf.len() == size { + return Ok(Async::Ready(Some(buf))); + } + } + } + + match self.poll_stream()? { + Async::Ready(true) => self.copy(size), + Async::Ready(false) => Ok(Async::Ready(None)), + Async::NotReady => Ok(Async::NotReady), + } + } + + /// Read until specified ending + pub fn read_until(&mut self, line: &[u8]) -> Poll, PayloadError> { + let mut idx = 0; + let mut num = 0; + let mut offset = 0; + let mut found = false; + let mut length = 0; + + for no in 0..self.items.len() { + { + let chunk = &self.items[no]; + for (pos, ch) in chunk.iter().enumerate() { + if *ch == line[idx] { + idx += 1; + if idx == line.len() { + num = no; + offset = pos + 1; + length += pos + 1; + found = true; + break; + } + } else { + idx = 0 + } + } + if !found { + length += chunk.len() + } + } + + if found { + let mut buf = BytesMut::with_capacity(length); + if num > 0 { + for _ in 0..num { + buf.extend_from_slice(&self.items.pop_front().unwrap()); + } + } + if offset > 0 { + let mut chunk = self.items.pop_front().unwrap(); + buf.extend_from_slice(&chunk.split_to(offset)); + if !chunk.is_empty() { + self.items.push_front(chunk) + } + } + self.len -= length; + return Ok(Async::Ready(Some(buf.freeze()))); + } + } + + match self.poll_stream()? { + Async::Ready(true) => self.read_until(line), + Async::Ready(false) => Ok(Async::Ready(None)), + Async::NotReady => Ok(Async::NotReady), + } + } + + /// Read bytes until new line delimiter + pub fn readline(&mut self) -> Poll, PayloadError> { + self.read_until(b"\n") + } + + /// Put unprocessed data back to the buffer + pub fn unprocessed(&mut self, data: Bytes) { + self.len += data.len(); + self.items.push_front(data); + } + + /// Get remaining data from the buffer + pub fn remaining(&mut self) -> Bytes { + self.items + .iter_mut() + .fold(BytesMut::new(), |mut b, c| { + b.extend_from_slice(c); + b + }) + .freeze() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use actix_rt::Runtime; + use futures::future::{lazy, result}; + + #[test] + fn test_error() { + let err = PayloadError::Incomplete(None); + assert_eq!( + format!("{}", err), + "A payload reached EOF, but is not complete." + ); + } + + #[test] + fn test_basic() { + Runtime::new() + .unwrap() + .block_on(lazy(|| { + let (_, payload) = Payload::create(false); + let mut payload = PayloadBuffer::new(payload); + + assert_eq!(payload.len, 0); + assert_eq!(Async::NotReady, payload.readany().ok().unwrap()); + + let res: Result<(), ()> = Ok(()); + result(res) + })) + .unwrap(); + } + + #[test] + fn test_eof() { + Runtime::new() + .unwrap() + .block_on(lazy(|| { + let (mut sender, payload) = Payload::create(false); + let mut payload = PayloadBuffer::new(payload); + + assert_eq!(Async::NotReady, payload.readany().ok().unwrap()); + sender.feed_data(Bytes::from("data")); + sender.feed_eof(); + + assert_eq!( + Async::Ready(Some(Bytes::from("data"))), + payload.readany().ok().unwrap() + ); + assert_eq!(payload.len, 0); + assert_eq!(Async::Ready(None), payload.readany().ok().unwrap()); + + let res: Result<(), ()> = Ok(()); + result(res) + })) + .unwrap(); + } + + #[test] + fn test_err() { + Runtime::new() + .unwrap() + .block_on(lazy(|| { + let (mut sender, payload) = Payload::create(false); + let mut payload = PayloadBuffer::new(payload); + + assert_eq!(Async::NotReady, payload.readany().ok().unwrap()); + + sender.set_error(PayloadError::Incomplete(None)); + payload.readany().err().unwrap(); + let res: Result<(), ()> = Ok(()); + result(res) + })) + .unwrap(); + } + + #[test] + fn test_readany() { + Runtime::new() + .unwrap() + .block_on(lazy(|| { + let (mut sender, payload) = Payload::create(false); + let mut payload = PayloadBuffer::new(payload); + + sender.feed_data(Bytes::from("line1")); + sender.feed_data(Bytes::from("line2")); + + assert_eq!( + Async::Ready(Some(Bytes::from("line1"))), + payload.readany().ok().unwrap() + ); + assert_eq!(payload.len, 0); + + assert_eq!( + Async::Ready(Some(Bytes::from("line2"))), + payload.readany().ok().unwrap() + ); + assert_eq!(payload.len, 0); + + let res: Result<(), ()> = Ok(()); + result(res) + })) + .unwrap(); + } + + #[test] + fn test_readexactly() { + Runtime::new() + .unwrap() + .block_on(lazy(|| { + let (mut sender, payload) = Payload::create(false); + let mut payload = PayloadBuffer::new(payload); + + assert_eq!(Async::NotReady, payload.read_exact(2).ok().unwrap()); + + sender.feed_data(Bytes::from("line1")); + sender.feed_data(Bytes::from("line2")); + + assert_eq!( + Async::Ready(Some(Bytes::from_static(b"li"))), + payload.read_exact(2).ok().unwrap() + ); + assert_eq!(payload.len, 3); + + assert_eq!( + Async::Ready(Some(Bytes::from_static(b"ne1l"))), + payload.read_exact(4).ok().unwrap() + ); + assert_eq!(payload.len, 4); + + sender.set_error(PayloadError::Incomplete(None)); + payload.read_exact(10).err().unwrap(); + + let res: Result<(), ()> = Ok(()); + result(res) + })) + .unwrap(); + } + + #[test] + fn test_readuntil() { + Runtime::new() + .unwrap() + .block_on(lazy(|| { + let (mut sender, payload) = Payload::create(false); + let mut payload = PayloadBuffer::new(payload); + + assert_eq!(Async::NotReady, payload.read_until(b"ne").ok().unwrap()); + + sender.feed_data(Bytes::from("line1")); + sender.feed_data(Bytes::from("line2")); + + assert_eq!( + Async::Ready(Some(Bytes::from("line"))), + payload.read_until(b"ne").ok().unwrap() + ); + assert_eq!(payload.len, 1); + + assert_eq!( + Async::Ready(Some(Bytes::from("1line2"))), + payload.read_until(b"2").ok().unwrap() + ); + assert_eq!(payload.len, 0); + + sender.set_error(PayloadError::Incomplete(None)); + payload.read_until(b"b").err().unwrap(); + + let res: Result<(), ()> = Ok(()); + result(res) + })) + .unwrap(); + } + + #[test] + fn test_unread_data() { + Runtime::new() + .unwrap() + .block_on(lazy(|| { + let (_, mut payload) = Payload::create(false); + + payload.unread_data(Bytes::from("data")); + assert!(!payload.is_empty()); + assert_eq!(payload.len(), 4); + + assert_eq!( + Async::Ready(Some(Bytes::from("data"))), + payload.poll().ok().unwrap() + ); + + let res: Result<(), ()> = Ok(()); + result(res) + })) + .unwrap(); + } +} diff --git a/actix-http/src/h1/service.rs b/actix-http/src/h1/service.rs new file mode 100644 index 000000000..f3301b9b2 --- /dev/null +++ b/actix-http/src/h1/service.rs @@ -0,0 +1,257 @@ +use std::fmt::Debug; +use std::marker::PhantomData; + +use actix_codec::{AsyncRead, AsyncWrite, Framed}; +use actix_server_config::{Io, ServerConfig as SrvConfig}; +use actix_service::{IntoNewService, NewService, Service}; +use actix_utils::cloneable::CloneableService; +use futures::future::{ok, FutureResult}; +use futures::{try_ready, Async, Future, IntoFuture, Poll, Stream}; + +use crate::body::MessageBody; +use crate::config::{KeepAlive, ServiceConfig}; +use crate::error::{DispatchError, ParseError}; +use crate::request::Request; +use crate::response::Response; + +use super::codec::Codec; +use super::dispatcher::Dispatcher; +use super::Message; + +/// `NewService` implementation for HTTP1 transport +pub struct H1Service { + srv: S, + cfg: ServiceConfig, + _t: PhantomData<(T, P, B)>, +} + +impl H1Service +where + S: NewService, + S::Error: Debug, + S::Response: Into>, + S::Service: 'static, + B: MessageBody, +{ + /// Create new `HttpService` instance with default config. + pub fn new>(service: F) -> Self { + let cfg = ServiceConfig::new(KeepAlive::Timeout(5), 5000, 0); + + H1Service { + cfg, + srv: service.into_new_service(), + _t: PhantomData, + } + } + + /// Create new `HttpService` instance with config. + pub fn with_config>( + cfg: ServiceConfig, + service: F, + ) -> Self { + H1Service { + cfg, + srv: service.into_new_service(), + _t: PhantomData, + } + } +} + +impl NewService for H1Service +where + T: AsyncRead + AsyncWrite, + S: NewService, + S::Error: Debug, + S::Response: Into>, + S::Service: 'static, + B: MessageBody, +{ + type Request = Io; + type Response = (); + type Error = DispatchError; + type InitError = S::InitError; + type Service = H1ServiceHandler; + type Future = H1ServiceResponse; + + fn new_service(&self, cfg: &SrvConfig) -> Self::Future { + H1ServiceResponse { + fut: self.srv.new_service(cfg).into_future(), + cfg: Some(self.cfg.clone()), + _t: PhantomData, + } + } +} + +#[doc(hidden)] +pub struct H1ServiceResponse, B> { + fut: ::Future, + cfg: Option, + _t: PhantomData<(T, P, B)>, +} + +impl Future for H1ServiceResponse +where + T: AsyncRead + AsyncWrite, + S: NewService, + S::Service: 'static, + S::Error: Debug, + S::Response: Into>, + B: MessageBody, +{ + type Item = H1ServiceHandler; + type Error = S::InitError; + + fn poll(&mut self) -> Poll { + let service = try_ready!(self.fut.poll()); + Ok(Async::Ready(H1ServiceHandler::new( + self.cfg.take().unwrap(), + service, + ))) + } +} + +/// `Service` implementation for HTTP1 transport +pub struct H1ServiceHandler { + srv: CloneableService, + cfg: ServiceConfig, + _t: PhantomData<(T, P, B)>, +} + +impl H1ServiceHandler +where + S: Service, + S::Error: Debug, + S::Response: Into>, + B: MessageBody, +{ + fn new(cfg: ServiceConfig, srv: S) -> H1ServiceHandler { + H1ServiceHandler { + srv: CloneableService::new(srv), + cfg, + _t: PhantomData, + } + } +} + +impl Service for H1ServiceHandler +where + T: AsyncRead + AsyncWrite, + S: Service, + S::Error: Debug, + S::Response: Into>, + B: MessageBody, +{ + type Request = Io; + type Response = (); + type Error = DispatchError; + type Future = Dispatcher; + + fn poll_ready(&mut self) -> Poll<(), Self::Error> { + self.srv.poll_ready().map_err(|e| { + log::error!("Http service readiness error: {:?}", e); + DispatchError::Service + }) + } + + fn call(&mut self, req: Self::Request) -> Self::Future { + Dispatcher::new(req.into_parts().0, self.cfg.clone(), self.srv.clone()) + } +} + +/// `NewService` implementation for `OneRequestService` service +#[derive(Default)] +pub struct OneRequest { + config: ServiceConfig, + _t: PhantomData<(T, P)>, +} + +impl OneRequest +where + T: AsyncRead + AsyncWrite, +{ + /// Create new `H1SimpleService` instance. + pub fn new() -> Self { + OneRequest { + config: ServiceConfig::default(), + _t: PhantomData, + } + } +} + +impl NewService for OneRequest +where + T: AsyncRead + AsyncWrite, +{ + type Request = Io; + type Response = (Request, Framed); + type Error = ParseError; + type InitError = (); + type Service = OneRequestService; + type Future = FutureResult; + + fn new_service(&self, _: &SrvConfig) -> Self::Future { + ok(OneRequestService { + config: self.config.clone(), + _t: PhantomData, + }) + } +} + +/// `Service` implementation for HTTP1 transport. Reads one request and returns +/// request and framed object. +pub struct OneRequestService { + config: ServiceConfig, + _t: PhantomData<(T, P)>, +} + +impl Service for OneRequestService +where + T: AsyncRead + AsyncWrite, +{ + type Request = Io; + type Response = (Request, Framed); + type Error = ParseError; + type Future = OneRequestServiceResponse; + + fn poll_ready(&mut self) -> Poll<(), Self::Error> { + Ok(Async::Ready(())) + } + + fn call(&mut self, req: Self::Request) -> Self::Future { + OneRequestServiceResponse { + framed: Some(Framed::new( + req.into_parts().0, + Codec::new(self.config.clone()), + )), + } + } +} + +#[doc(hidden)] +pub struct OneRequestServiceResponse +where + T: AsyncRead + AsyncWrite, +{ + framed: Option>, +} + +impl Future for OneRequestServiceResponse +where + T: AsyncRead + AsyncWrite, +{ + type Item = (Request, Framed); + type Error = ParseError; + + fn poll(&mut self) -> Poll { + match self.framed.as_mut().unwrap().poll()? { + Async::Ready(Some(req)) => match req { + Message::Item(req) => { + Ok(Async::Ready((req, self.framed.take().unwrap()))) + } + Message::Chunk(_) => unreachable!("Something is wrong"), + }, + Async::Ready(None) => Err(ParseError::Incomplete), + Async::NotReady => Ok(Async::NotReady), + } + } +} diff --git a/actix-http/src/h2/dispatcher.rs b/actix-http/src/h2/dispatcher.rs new file mode 100644 index 000000000..ea63dc2bc --- /dev/null +++ b/actix-http/src/h2/dispatcher.rs @@ -0,0 +1,324 @@ +use std::collections::VecDeque; +use std::marker::PhantomData; +use std::time::Instant; +use std::{fmt, mem}; + +use actix_codec::{AsyncRead, AsyncWrite}; +use actix_service::Service; +use actix_utils::cloneable::CloneableService; +use bitflags::bitflags; +use bytes::{Bytes, BytesMut}; +use futures::{try_ready, Async, Future, Poll, Sink, Stream}; +use h2::server::{Connection, SendResponse}; +use h2::{RecvStream, SendStream}; +use http::header::{ + HeaderValue, ACCEPT_ENCODING, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING, +}; +use http::HttpTryFrom; +use log::{debug, error, trace}; +use tokio_timer::Delay; + +use crate::body::{Body, BodyLength, MessageBody, ResponseBody}; +use crate::config::ServiceConfig; +use crate::error::{DispatchError, Error, ParseError, PayloadError, ResponseError}; +use crate::message::ResponseHead; +use crate::payload::Payload; +use crate::request::Request; +use crate::response::Response; + +const CHUNK_SIZE: usize = 16_384; + +/// Dispatcher for HTTP/2 protocol +pub struct Dispatcher< + T: AsyncRead + AsyncWrite, + S: Service + 'static, + B: MessageBody, +> { + service: CloneableService, + connection: Connection, + config: ServiceConfig, + ka_expire: Instant, + ka_timer: Option, + _t: PhantomData, +} + +impl Dispatcher +where + T: AsyncRead + AsyncWrite, + S: Service + 'static, + S::Error: fmt::Debug, + S::Response: Into>, + B: MessageBody + 'static, +{ + pub fn new( + service: CloneableService, + connection: Connection, + config: ServiceConfig, + timeout: Option, + ) -> Self { + // let keepalive = config.keep_alive_enabled(); + // let flags = if keepalive { + // Flags::KEEPALIVE | Flags::KEEPALIVE_ENABLED + // } else { + // Flags::empty() + // }; + + // keep-alive timer + let (ka_expire, ka_timer) = if let Some(delay) = timeout { + (delay.deadline(), Some(delay)) + } else if let Some(delay) = config.keep_alive_timer() { + (delay.deadline(), Some(delay)) + } else { + (config.now(), None) + }; + + Dispatcher { + service, + config, + ka_expire, + ka_timer, + connection, + _t: PhantomData, + } + } +} + +impl Future for Dispatcher +where + T: AsyncRead + AsyncWrite, + S: Service + 'static, + S::Error: fmt::Debug, + S::Response: Into>, + B: MessageBody + 'static, +{ + type Item = (); + type Error = DispatchError; + + #[inline] + fn poll(&mut self) -> Poll { + loop { + match self.connection.poll()? { + Async::Ready(None) => return Ok(Async::Ready(())), + Async::Ready(Some((req, res))) => { + // update keep-alive expire + if self.ka_timer.is_some() { + if let Some(expire) = self.config.keep_alive_expire() { + self.ka_expire = expire; + } + } + + let (parts, body) = req.into_parts(); + let mut req = Request::with_payload(body.into()); + + let head = &mut req.head_mut(); + head.uri = parts.uri; + head.method = parts.method; + head.version = parts.version; + head.headers = parts.headers; + tokio_current_thread::spawn(ServiceResponse:: { + state: ServiceResponseState::ServiceCall( + self.service.call(req), + Some(res), + ), + config: self.config.clone(), + buffer: None, + }) + } + Async::NotReady => return Ok(Async::NotReady), + } + } + } +} + +struct ServiceResponse { + state: ServiceResponseState, + config: ServiceConfig, + buffer: Option, +} + +enum ServiceResponseState { + ServiceCall(S::Future, Option>), + SendPayload(SendStream, ResponseBody), +} + +impl ServiceResponse +where + S: Service + 'static, + S::Error: fmt::Debug, + S::Response: Into>, + B: MessageBody + 'static, +{ + fn prepare_response( + &self, + head: &ResponseHead, + length: &mut BodyLength, + ) -> http::Response<()> { + let mut has_date = false; + let mut skip_len = length != &BodyLength::Stream; + + let mut res = http::Response::new(()); + *res.status_mut() = head.status; + *res.version_mut() = http::Version::HTTP_2; + + // Content length + match head.status { + http::StatusCode::NO_CONTENT + | http::StatusCode::CONTINUE + | http::StatusCode::PROCESSING => *length = BodyLength::None, + http::StatusCode::SWITCHING_PROTOCOLS => { + skip_len = true; + *length = BodyLength::Stream; + } + _ => (), + } + let _ = match length { + BodyLength::None | BodyLength::Stream => None, + BodyLength::Empty => res + .headers_mut() + .insert(CONTENT_LENGTH, HeaderValue::from_static("0")), + BodyLength::Sized(len) => res.headers_mut().insert( + CONTENT_LENGTH, + HeaderValue::try_from(format!("{}", len)).unwrap(), + ), + BodyLength::Sized64(len) => res.headers_mut().insert( + CONTENT_LENGTH, + HeaderValue::try_from(format!("{}", len)).unwrap(), + ), + }; + + // copy headers + for (key, value) in head.headers.iter() { + match *key { + CONNECTION | TRANSFER_ENCODING => continue, // http2 specific + CONTENT_LENGTH if skip_len => continue, + DATE => has_date = true, + _ => (), + } + res.headers_mut().append(key, value.clone()); + } + + // set date header + if !has_date { + let mut bytes = BytesMut::with_capacity(29); + self.config.set_date_header(&mut bytes); + res.headers_mut() + .insert(DATE, HeaderValue::try_from(bytes.freeze()).unwrap()); + } + + res + } +} + +impl Future for ServiceResponse +where + S: Service + 'static, + S::Error: fmt::Debug, + S::Response: Into>, + B: MessageBody + 'static, +{ + type Item = (); + type Error = (); + + fn poll(&mut self) -> Poll { + match self.state { + ServiceResponseState::ServiceCall(ref mut call, ref mut send) => { + match call.poll() { + Ok(Async::Ready(res)) => { + let (res, body) = res.into().replace_body(()); + + let mut send = send.take().unwrap(); + let mut length = body.length(); + let h2_res = self.prepare_response(res.head(), &mut length); + + let stream = send + .send_response(h2_res, length.is_eof()) + .map_err(|e| { + trace!("Error sending h2 response: {:?}", e); + })?; + + if length.is_eof() { + Ok(Async::Ready(())) + } else { + self.state = ServiceResponseState::SendPayload(stream, body); + self.poll() + } + } + Ok(Async::NotReady) => Ok(Async::NotReady), + Err(_e) => { + let res: Response = Response::InternalServerError().finish(); + let (res, body) = res.replace_body(()); + + let mut send = send.take().unwrap(); + let mut length = body.length(); + let h2_res = self.prepare_response(res.head(), &mut length); + + let stream = send + .send_response(h2_res, length.is_eof()) + .map_err(|e| { + trace!("Error sending h2 response: {:?}", e); + })?; + + if length.is_eof() { + Ok(Async::Ready(())) + } else { + self.state = ServiceResponseState::SendPayload( + stream, + body.into_body(), + ); + self.poll() + } + } + } + } + ServiceResponseState::SendPayload(ref mut stream, ref mut body) => loop { + loop { + if let Some(ref mut buffer) = self.buffer { + match stream.poll_capacity().map_err(|e| warn!("{:?}", e))? { + Async::NotReady => return Ok(Async::NotReady), + Async::Ready(None) => return Ok(Async::Ready(())), + Async::Ready(Some(cap)) => { + let len = buffer.len(); + let bytes = buffer.split_to(std::cmp::min(cap, len)); + + if let Err(e) = stream.send_data(bytes, false) { + warn!("{:?}", e); + return Err(()); + } else if !buffer.is_empty() { + let cap = std::cmp::min(buffer.len(), CHUNK_SIZE); + stream.reserve_capacity(cap); + } else { + self.buffer.take(); + } + } + } + } else { + match body.poll_next() { + Ok(Async::NotReady) => { + return Ok(Async::NotReady); + } + Ok(Async::Ready(None)) => { + if let Err(e) = stream.send_data(Bytes::new(), true) { + warn!("{:?}", e); + return Err(()); + } else { + return Ok(Async::Ready(())); + } + } + Ok(Async::Ready(Some(chunk))) => { + stream.reserve_capacity(std::cmp::min( + chunk.len(), + CHUNK_SIZE, + )); + self.buffer = Some(chunk); + } + Err(e) => { + error!("Response payload stream error: {:?}", e); + return Err(()); + } + } + } + } + }, + } + } +} diff --git a/actix-http/src/h2/mod.rs b/actix-http/src/h2/mod.rs new file mode 100644 index 000000000..c5972123f --- /dev/null +++ b/actix-http/src/h2/mod.rs @@ -0,0 +1,46 @@ +#![allow(dead_code, unused_imports)] + +use std::fmt; + +use bytes::Bytes; +use futures::{Async, Poll, Stream}; +use h2::RecvStream; + +mod dispatcher; +mod service; + +pub use self::dispatcher::Dispatcher; +pub use self::service::H2Service; +use crate::error::PayloadError; + +/// H2 receive stream +pub struct Payload { + pl: RecvStream, +} + +impl Payload { + pub(crate) fn new(pl: RecvStream) -> Self { + Self { pl } + } +} + +impl Stream for Payload { + type Item = Bytes; + type Error = PayloadError; + + fn poll(&mut self) -> Poll, Self::Error> { + match self.pl.poll() { + Ok(Async::Ready(Some(chunk))) => { + let len = chunk.len(); + if let Err(err) = self.pl.release_capacity().release_capacity(len) { + Err(err.into()) + } else { + Ok(Async::Ready(Some(chunk))) + } + } + Ok(Async::Ready(None)) => Ok(Async::Ready(None)), + Ok(Async::NotReady) => Ok(Async::NotReady), + Err(err) => Err(err.into()), + } + } +} diff --git a/actix-http/src/h2/service.rs b/actix-http/src/h2/service.rs new file mode 100644 index 000000000..6ab37919c --- /dev/null +++ b/actix-http/src/h2/service.rs @@ -0,0 +1,229 @@ +use std::fmt::Debug; +use std::marker::PhantomData; +use std::{io, net}; + +use actix_codec::{AsyncRead, AsyncWrite, Framed}; +use actix_server_config::{Io, ServerConfig as SrvConfig}; +use actix_service::{IntoNewService, NewService, Service}; +use actix_utils::cloneable::CloneableService; +use bytes::Bytes; +use futures::future::{ok, FutureResult}; +use futures::{try_ready, Async, Future, IntoFuture, Poll, Stream}; +use h2::server::{self, Connection, Handshake}; +use h2::RecvStream; +use log::error; + +use crate::body::MessageBody; +use crate::config::{KeepAlive, ServiceConfig}; +use crate::error::{DispatchError, Error, ParseError, ResponseError}; +use crate::payload::Payload; +use crate::request::Request; +use crate::response::Response; + +use super::dispatcher::Dispatcher; + +/// `NewService` implementation for HTTP2 transport +pub struct H2Service { + srv: S, + cfg: ServiceConfig, + _t: PhantomData<(T, P, B)>, +} + +impl H2Service +where + S: NewService, + S::Service: 'static, + S::Error: Debug + 'static, + S::Response: Into>, + B: MessageBody + 'static, +{ + /// Create new `HttpService` instance. + pub fn new>(service: F) -> Self { + let cfg = ServiceConfig::new(KeepAlive::Timeout(5), 5000, 0); + + H2Service { + cfg, + srv: service.into_new_service(), + _t: PhantomData, + } + } + + /// Create new `HttpService` instance with config. + pub fn with_config>( + cfg: ServiceConfig, + service: F, + ) -> Self { + H2Service { + cfg, + srv: service.into_new_service(), + _t: PhantomData, + } + } +} + +impl NewService for H2Service +where + T: AsyncRead + AsyncWrite, + S: NewService, + S::Service: 'static, + S::Error: Debug, + S::Response: Into>, + B: MessageBody + 'static, +{ + type Request = Io; + type Response = (); + type Error = DispatchError; + type InitError = S::InitError; + type Service = H2ServiceHandler; + type Future = H2ServiceResponse; + + fn new_service(&self, cfg: &SrvConfig) -> Self::Future { + H2ServiceResponse { + fut: self.srv.new_service(cfg).into_future(), + cfg: Some(self.cfg.clone()), + _t: PhantomData, + } + } +} + +#[doc(hidden)] +pub struct H2ServiceResponse, B> { + fut: ::Future, + cfg: Option, + _t: PhantomData<(T, P, B)>, +} + +impl Future for H2ServiceResponse +where + T: AsyncRead + AsyncWrite, + S: NewService, + S::Service: 'static, + S::Response: Into>, + S::Error: Debug, + B: MessageBody + 'static, +{ + type Item = H2ServiceHandler; + type Error = S::InitError; + + fn poll(&mut self) -> Poll { + let service = try_ready!(self.fut.poll()); + Ok(Async::Ready(H2ServiceHandler::new( + self.cfg.take().unwrap(), + service, + ))) + } +} + +/// `Service` implementation for http/2 transport +pub struct H2ServiceHandler { + srv: CloneableService, + cfg: ServiceConfig, + _t: PhantomData<(T, P, B)>, +} + +impl H2ServiceHandler +where + S: Service + 'static, + S::Error: Debug, + S::Response: Into>, + B: MessageBody + 'static, +{ + fn new(cfg: ServiceConfig, srv: S) -> H2ServiceHandler { + H2ServiceHandler { + cfg, + srv: CloneableService::new(srv), + _t: PhantomData, + } + } +} + +impl Service for H2ServiceHandler +where + T: AsyncRead + AsyncWrite, + S: Service + 'static, + S::Error: Debug, + S::Response: Into>, + B: MessageBody + 'static, +{ + type Request = Io; + type Response = (); + type Error = DispatchError; + type Future = H2ServiceHandlerResponse; + + fn poll_ready(&mut self) -> Poll<(), Self::Error> { + self.srv.poll_ready().map_err(|e| { + error!("Service readiness error: {:?}", e); + DispatchError::Service + }) + } + + fn call(&mut self, req: Self::Request) -> Self::Future { + H2ServiceHandlerResponse { + state: State::Handshake( + Some(self.srv.clone()), + Some(self.cfg.clone()), + server::handshake(req.into_parts().0), + ), + } + } +} + +enum State< + T: AsyncRead + AsyncWrite, + S: Service + 'static, + B: MessageBody, +> { + Incoming(Dispatcher), + Handshake( + Option>, + Option, + Handshake, + ), +} + +pub struct H2ServiceHandlerResponse +where + T: AsyncRead + AsyncWrite, + S: Service + 'static, + S::Error: Debug, + S::Response: Into>, + B: MessageBody + 'static, +{ + state: State, +} + +impl Future for H2ServiceHandlerResponse +where + T: AsyncRead + AsyncWrite, + S: Service + 'static, + S::Error: Debug, + S::Response: Into>, + B: MessageBody, +{ + type Item = (); + type Error = DispatchError; + + fn poll(&mut self) -> Poll { + match self.state { + State::Incoming(ref mut disp) => disp.poll(), + State::Handshake(ref mut srv, ref mut config, ref mut handshake) => { + match handshake.poll() { + Ok(Async::Ready(conn)) => { + self.state = State::Incoming(Dispatcher::new( + srv.take().unwrap(), + conn, + config.take().unwrap(), + None, + )); + self.poll() + } + Ok(Async::NotReady) => Ok(Async::NotReady), + Err(err) => { + trace!("H2 handshake error: {}", err); + return Err(err.into()); + } + } + } + } + } +} diff --git a/actix-http/src/header/common/accept.rs b/actix-http/src/header/common/accept.rs new file mode 100644 index 000000000..d52eba241 --- /dev/null +++ b/actix-http/src/header/common/accept.rs @@ -0,0 +1,160 @@ +use mime::Mime; + +use crate::header::{qitem, QualityItem}; +use crate::http::header; + +header! { + /// `Accept` header, defined in [RFC7231](http://tools.ietf.org/html/rfc7231#section-5.3.2) + /// + /// The `Accept` header field can be used by user agents to specify + /// response media types that are acceptable. Accept header fields can + /// be used to indicate that the request is specifically limited to a + /// small set of desired types, as in the case of a request for an + /// in-line image + /// + /// # ABNF + /// + /// ```text + /// Accept = #( media-range [ accept-params ] ) + /// + /// media-range = ( "*/*" + /// / ( type "/" "*" ) + /// / ( type "/" subtype ) + /// ) *( OWS ";" OWS parameter ) + /// accept-params = weight *( accept-ext ) + /// accept-ext = OWS ";" OWS token [ "=" ( token / quoted-string ) ] + /// ``` + /// + /// # Example values + /// * `audio/*; q=0.2, audio/basic` + /// * `text/plain; q=0.5, text/html, text/x-dvi; q=0.8, text/x-c` + /// + /// # Examples + /// ```rust + /// # extern crate actix_http; + /// extern crate mime; + /// use actix_http::Response; + /// use actix_http::http::header::{Accept, qitem}; + /// + /// # fn main() { + /// let mut builder = Response::Ok(); + /// + /// builder.set( + /// Accept(vec![ + /// qitem(mime::TEXT_HTML), + /// ]) + /// ); + /// # } + /// ``` + /// + /// ```rust + /// # extern crate actix_http; + /// extern crate mime; + /// use actix_http::Response; + /// use actix_http::http::header::{Accept, qitem}; + /// + /// # fn main() { + /// let mut builder = Response::Ok(); + /// + /// builder.set( + /// Accept(vec![ + /// qitem(mime::APPLICATION_JSON), + /// ]) + /// ); + /// # } + /// ``` + /// + /// ```rust + /// # extern crate actix_http; + /// extern crate mime; + /// use actix_http::Response; + /// use actix_http::http::header::{Accept, QualityItem, q, qitem}; + /// + /// # fn main() { + /// let mut builder = Response::Ok(); + /// + /// builder.set( + /// Accept(vec![ + /// qitem(mime::TEXT_HTML), + /// qitem("application/xhtml+xml".parse().unwrap()), + /// QualityItem::new( + /// mime::TEXT_XML, + /// q(900) + /// ), + /// qitem("image/webp".parse().unwrap()), + /// QualityItem::new( + /// mime::STAR_STAR, + /// q(800) + /// ), + /// ]) + /// ); + /// # } + /// ``` + (Accept, header::ACCEPT) => (QualityItem)+ + + test_accept { + // Tests from the RFC + test_header!( + test1, + vec![b"audio/*; q=0.2, audio/basic"], + Some(HeaderField(vec![ + QualityItem::new("audio/*".parse().unwrap(), q(200)), + qitem("audio/basic".parse().unwrap()), + ]))); + test_header!( + test2, + vec![b"text/plain; q=0.5, text/html, text/x-dvi; q=0.8, text/x-c"], + Some(HeaderField(vec![ + QualityItem::new(mime::TEXT_PLAIN, q(500)), + qitem(mime::TEXT_HTML), + QualityItem::new( + "text/x-dvi".parse().unwrap(), + q(800)), + qitem("text/x-c".parse().unwrap()), + ]))); + // Custom tests + test_header!( + test3, + vec![b"text/plain; charset=utf-8"], + Some(Accept(vec![ + qitem(mime::TEXT_PLAIN_UTF_8), + ]))); + test_header!( + test4, + vec![b"text/plain; charset=utf-8; q=0.5"], + Some(Accept(vec![ + QualityItem::new(mime::TEXT_PLAIN_UTF_8, + q(500)), + ]))); + + #[test] + fn test_fuzzing1() { + use crate::test::TestRequest; + let req = TestRequest::with_header(crate::header::ACCEPT, "chunk#;e").finish(); + let header = Accept::parse(&req); + assert!(header.is_ok()); + } + } +} + +impl Accept { + /// A constructor to easily create `Accept: */*`. + pub fn star() -> Accept { + Accept(vec![qitem(mime::STAR_STAR)]) + } + + /// A constructor to easily create `Accept: application/json`. + pub fn json() -> Accept { + Accept(vec![qitem(mime::APPLICATION_JSON)]) + } + + /// A constructor to easily create `Accept: text/*`. + pub fn text() -> Accept { + Accept(vec![qitem(mime::TEXT_STAR)]) + } + + /// A constructor to easily create `Accept: image/*`. + pub fn image() -> Accept { + Accept(vec![qitem(mime::IMAGE_STAR)]) + } +} diff --git a/actix-http/src/header/common/accept_charset.rs b/actix-http/src/header/common/accept_charset.rs new file mode 100644 index 000000000..117e2015d --- /dev/null +++ b/actix-http/src/header/common/accept_charset.rs @@ -0,0 +1,69 @@ +use crate::header::{Charset, QualityItem, ACCEPT_CHARSET}; + +header! { + /// `Accept-Charset` header, defined in + /// [RFC7231](http://tools.ietf.org/html/rfc7231#section-5.3.3) + /// + /// The `Accept-Charset` header field can be sent by a user agent to + /// indicate what charsets are acceptable in textual response content. + /// This field allows user agents capable of understanding more + /// comprehensive or special-purpose charsets to signal that capability + /// to an origin server that is capable of representing information in + /// those charsets. + /// + /// # ABNF + /// + /// ```text + /// Accept-Charset = 1#( ( charset / "*" ) [ weight ] ) + /// ``` + /// + /// # Example values + /// * `iso-8859-5, unicode-1-1;q=0.8` + /// + /// # Examples + /// ```rust + /// # extern crate actix_http; + /// use actix_http::Response; + /// use actix_http::http::header::{AcceptCharset, Charset, qitem}; + /// + /// # fn main() { + /// let mut builder = Response::Ok(); + /// builder.set( + /// AcceptCharset(vec![qitem(Charset::Us_Ascii)]) + /// ); + /// # } + /// ``` + /// ```rust + /// # extern crate actix_http; + /// use actix_http::Response; + /// use actix_http::http::header::{AcceptCharset, Charset, q, QualityItem}; + /// + /// # fn main() { + /// let mut builder = Response::Ok(); + /// builder.set( + /// AcceptCharset(vec![ + /// QualityItem::new(Charset::Us_Ascii, q(900)), + /// QualityItem::new(Charset::Iso_8859_10, q(200)), + /// ]) + /// ); + /// # } + /// ``` + /// ```rust + /// # extern crate actix_http; + /// use actix_http::Response; + /// use actix_http::http::header::{AcceptCharset, Charset, qitem}; + /// + /// # fn main() { + /// let mut builder = Response::Ok(); + /// builder.set( + /// AcceptCharset(vec![qitem(Charset::Ext("utf-8".to_owned()))]) + /// ); + /// # } + /// ``` + (AcceptCharset, ACCEPT_CHARSET) => (QualityItem)+ + + test_accept_charset { + /// Test case from RFC + test_header!(test1, vec![b"iso-8859-5, unicode-1-1;q=0.8"]); + } +} diff --git a/actix-http/src/header/common/accept_encoding.rs b/actix-http/src/header/common/accept_encoding.rs new file mode 100644 index 000000000..c90f529bc --- /dev/null +++ b/actix-http/src/header/common/accept_encoding.rs @@ -0,0 +1,72 @@ +use header::{Encoding, QualityItem}; + +header! { + /// `Accept-Encoding` header, defined in + /// [RFC7231](http://tools.ietf.org/html/rfc7231#section-5.3.4) + /// + /// The `Accept-Encoding` header field can be used by user agents to + /// indicate what response content-codings are + /// acceptable in the response. An `identity` token is used as a synonym + /// for "no encoding" in order to communicate when no encoding is + /// preferred. + /// + /// # ABNF + /// + /// ```text + /// Accept-Encoding = #( codings [ weight ] ) + /// codings = content-coding / "identity" / "*" + /// ``` + /// + /// # Example values + /// * `compress, gzip` + /// * `` + /// * `*` + /// * `compress;q=0.5, gzip;q=1` + /// * `gzip;q=1.0, identity; q=0.5, *;q=0` + /// + /// # Examples + /// ``` + /// use hyper::header::{Headers, AcceptEncoding, Encoding, qitem}; + /// + /// let mut headers = Headers::new(); + /// headers.set( + /// AcceptEncoding(vec![qitem(Encoding::Chunked)]) + /// ); + /// ``` + /// ``` + /// use hyper::header::{Headers, AcceptEncoding, Encoding, qitem}; + /// + /// let mut headers = Headers::new(); + /// headers.set( + /// AcceptEncoding(vec![ + /// qitem(Encoding::Chunked), + /// qitem(Encoding::Gzip), + /// qitem(Encoding::Deflate), + /// ]) + /// ); + /// ``` + /// ``` + /// use hyper::header::{Headers, AcceptEncoding, Encoding, QualityItem, q, qitem}; + /// + /// let mut headers = Headers::new(); + /// headers.set( + /// AcceptEncoding(vec![ + /// qitem(Encoding::Chunked), + /// QualityItem::new(Encoding::Gzip, q(600)), + /// QualityItem::new(Encoding::EncodingExt("*".to_owned()), q(0)), + /// ]) + /// ); + /// ``` + (AcceptEncoding, "Accept-Encoding") => (QualityItem)* + + test_accept_encoding { + // From the RFC + test_header!(test1, vec![b"compress, gzip"]); + test_header!(test2, vec![b""], Some(AcceptEncoding(vec![]))); + test_header!(test3, vec![b"*"]); + // Note: Removed quality 1 from gzip + test_header!(test4, vec![b"compress;q=0.5, gzip"]); + // Note: Removed quality 1 from gzip + test_header!(test5, vec![b"gzip, identity; q=0.5, *;q=0"]); + } +} diff --git a/actix-http/src/header/common/accept_language.rs b/actix-http/src/header/common/accept_language.rs new file mode 100644 index 000000000..55879b57f --- /dev/null +++ b/actix-http/src/header/common/accept_language.rs @@ -0,0 +1,75 @@ +use crate::header::{QualityItem, ACCEPT_LANGUAGE}; +use language_tags::LanguageTag; + +header! { + /// `Accept-Language` header, defined in + /// [RFC7231](http://tools.ietf.org/html/rfc7231#section-5.3.5) + /// + /// The `Accept-Language` header field can be used by user agents to + /// indicate the set of natural languages that are preferred in the + /// response. + /// + /// # ABNF + /// + /// ```text + /// Accept-Language = 1#( language-range [ weight ] ) + /// language-range = + /// ``` + /// + /// # Example values + /// * `da, en-gb;q=0.8, en;q=0.7` + /// * `en-us;q=1.0, en;q=0.5, fr` + /// + /// # Examples + /// + /// ```rust + /// # extern crate actix_http; + /// # extern crate language_tags; + /// use actix_http::Response; + /// use actix_http::http::header::{AcceptLanguage, LanguageTag, qitem}; + /// + /// # fn main() { + /// let mut builder = Response::Ok(); + /// let mut langtag: LanguageTag = Default::default(); + /// langtag.language = Some("en".to_owned()); + /// langtag.region = Some("US".to_owned()); + /// builder.set( + /// AcceptLanguage(vec![ + /// qitem(langtag), + /// ]) + /// ); + /// # } + /// ``` + /// + /// ```rust + /// # extern crate actix_http; + /// # #[macro_use] extern crate language_tags; + /// use actix_http::Response; + /// use actix_http::http::header::{AcceptLanguage, QualityItem, q, qitem}; + /// # + /// # fn main() { + /// let mut builder = Response::Ok(); + /// builder.set( + /// AcceptLanguage(vec![ + /// qitem(langtag!(da)), + /// QualityItem::new(langtag!(en;;;GB), q(800)), + /// QualityItem::new(langtag!(en), q(700)), + /// ]) + /// ); + /// # } + /// ``` + (AcceptLanguage, ACCEPT_LANGUAGE) => (QualityItem)+ + + test_accept_language { + // From the RFC + test_header!(test1, vec![b"da, en-gb;q=0.8, en;q=0.7"]); + // Own test + test_header!( + test2, vec![b"en-US, en; q=0.5, fr"], + Some(AcceptLanguage(vec![ + qitem("en-US".parse().unwrap()), + QualityItem::new("en".parse().unwrap(), q(500)), + qitem("fr".parse().unwrap()), + ]))); + } +} diff --git a/actix-http/src/header/common/allow.rs b/actix-http/src/header/common/allow.rs new file mode 100644 index 000000000..432cc00d5 --- /dev/null +++ b/actix-http/src/header/common/allow.rs @@ -0,0 +1,85 @@ +use http::Method; +use http::header; + +header! { + /// `Allow` header, defined in [RFC7231](http://tools.ietf.org/html/rfc7231#section-7.4.1) + /// + /// The `Allow` header field lists the set of methods advertised as + /// supported by the target resource. The purpose of this field is + /// strictly to inform the recipient of valid request methods associated + /// with the resource. + /// + /// # ABNF + /// + /// ```text + /// Allow = #method + /// ``` + /// + /// # Example values + /// * `GET, HEAD, PUT` + /// * `OPTIONS, GET, PUT, POST, DELETE, HEAD, TRACE, CONNECT, PATCH, fOObAr` + /// * `` + /// + /// # Examples + /// + /// ```rust + /// # extern crate http; + /// # extern crate actix_http; + /// use actix_http::Response; + /// use actix_http::http::header::Allow; + /// use http::Method; + /// + /// # fn main() { + /// let mut builder = Response::Ok(); + /// builder.set( + /// Allow(vec![Method::GET]) + /// ); + /// # } + /// ``` + /// + /// ```rust + /// # extern crate http; + /// # extern crate actix_http; + /// use actix_http::Response; + /// use actix_http::http::header::Allow; + /// use http::Method; + /// + /// # fn main() { + /// let mut builder = Response::Ok(); + /// builder.set( + /// Allow(vec![ + /// Method::GET, + /// Method::POST, + /// Method::PATCH, + /// ]) + /// ); + /// # } + /// ``` + (Allow, header::ALLOW) => (Method)* + + test_allow { + // From the RFC + test_header!( + test1, + vec![b"GET, HEAD, PUT"], + Some(HeaderField(vec![Method::GET, Method::HEAD, Method::PUT]))); + // Own tests + test_header!( + test2, + vec![b"OPTIONS, GET, PUT, POST, DELETE, HEAD, TRACE, CONNECT, PATCH"], + Some(HeaderField(vec![ + Method::OPTIONS, + Method::GET, + Method::PUT, + Method::POST, + Method::DELETE, + Method::HEAD, + Method::TRACE, + Method::CONNECT, + Method::PATCH]))); + test_header!( + test3, + vec![b""], + Some(HeaderField(Vec::::new()))); + } +} diff --git a/actix-http/src/header/common/cache_control.rs b/actix-http/src/header/common/cache_control.rs new file mode 100644 index 000000000..0b79ea7c0 --- /dev/null +++ b/actix-http/src/header/common/cache_control.rs @@ -0,0 +1,257 @@ +use std::fmt::{self, Write}; +use std::str::FromStr; + +use http::header; + +use crate::header::{ + fmt_comma_delimited, from_comma_delimited, Header, IntoHeaderValue, Writer, +}; + +/// `Cache-Control` header, defined in [RFC7234](https://tools.ietf.org/html/rfc7234#section-5.2) +/// +/// The `Cache-Control` header field is used to specify directives for +/// caches along the request/response chain. Such cache directives are +/// unidirectional in that the presence of a directive in a request does +/// not imply that the same directive is to be given in the response. +/// +/// # ABNF +/// +/// ```text +/// Cache-Control = 1#cache-directive +/// cache-directive = token [ "=" ( token / quoted-string ) ] +/// ``` +/// +/// # Example values +/// +/// * `no-cache` +/// * `private, community="UCI"` +/// * `max-age=30` +/// +/// # Examples +/// ```rust +/// use actix_http::Response; +/// use actix_http::http::header::{CacheControl, CacheDirective}; +/// +/// let mut builder = Response::Ok(); +/// builder.set(CacheControl(vec![CacheDirective::MaxAge(86400u32)])); +/// ``` +/// +/// ```rust +/// use actix_http::Response; +/// use actix_http::http::header::{CacheControl, CacheDirective}; +/// +/// let mut builder = Response::Ok(); +/// builder.set(CacheControl(vec![ +/// CacheDirective::NoCache, +/// CacheDirective::Private, +/// CacheDirective::MaxAge(360u32), +/// CacheDirective::Extension("foo".to_owned(), Some("bar".to_owned())), +/// ])); +/// ``` +#[derive(PartialEq, Clone, Debug)] +pub struct CacheControl(pub Vec); + +__hyper__deref!(CacheControl => Vec); + +//TODO: this could just be the header! macro +impl Header for CacheControl { + fn name() -> header::HeaderName { + header::CACHE_CONTROL + } + + #[inline] + fn parse(msg: &T) -> Result + where + T: crate::HttpMessage, + { + let directives = from_comma_delimited(msg.headers().get_all(Self::name()))?; + if !directives.is_empty() { + Ok(CacheControl(directives)) + } else { + Err(crate::error::ParseError::Header) + } + } +} + +impl fmt::Display for CacheControl { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fmt_comma_delimited(f, &self[..]) + } +} + +impl IntoHeaderValue for CacheControl { + type Error = header::InvalidHeaderValueBytes; + + fn try_into(self) -> Result { + let mut writer = Writer::new(); + let _ = write!(&mut writer, "{}", self); + header::HeaderValue::from_shared(writer.take()) + } +} + +/// `CacheControl` contains a list of these directives. +#[derive(PartialEq, Clone, Debug)] +pub enum CacheDirective { + /// "no-cache" + NoCache, + /// "no-store" + NoStore, + /// "no-transform" + NoTransform, + /// "only-if-cached" + OnlyIfCached, + + // request directives + /// "max-age=delta" + MaxAge(u32), + /// "max-stale=delta" + MaxStale(u32), + /// "min-fresh=delta" + MinFresh(u32), + + // response directives + /// "must-revalidate" + MustRevalidate, + /// "public" + Public, + /// "private" + Private, + /// "proxy-revalidate" + ProxyRevalidate, + /// "s-maxage=delta" + SMaxAge(u32), + + /// Extension directives. Optionally include an argument. + Extension(String, Option), +} + +impl fmt::Display for CacheDirective { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + use self::CacheDirective::*; + fmt::Display::fmt( + match *self { + NoCache => "no-cache", + NoStore => "no-store", + NoTransform => "no-transform", + OnlyIfCached => "only-if-cached", + + MaxAge(secs) => return write!(f, "max-age={}", secs), + MaxStale(secs) => return write!(f, "max-stale={}", secs), + MinFresh(secs) => return write!(f, "min-fresh={}", secs), + + MustRevalidate => "must-revalidate", + Public => "public", + Private => "private", + ProxyRevalidate => "proxy-revalidate", + SMaxAge(secs) => return write!(f, "s-maxage={}", secs), + + Extension(ref name, None) => &name[..], + Extension(ref name, Some(ref arg)) => { + return write!(f, "{}={}", name, arg); + } + }, + f, + ) + } +} + +impl FromStr for CacheDirective { + type Err = Option<::Err>; + fn from_str(s: &str) -> Result::Err>> { + use self::CacheDirective::*; + match s { + "no-cache" => Ok(NoCache), + "no-store" => Ok(NoStore), + "no-transform" => Ok(NoTransform), + "only-if-cached" => Ok(OnlyIfCached), + "must-revalidate" => Ok(MustRevalidate), + "public" => Ok(Public), + "private" => Ok(Private), + "proxy-revalidate" => Ok(ProxyRevalidate), + "" => Err(None), + _ => match s.find('=') { + Some(idx) if idx + 1 < s.len() => { + match (&s[..idx], (&s[idx + 1..]).trim_matches('"')) { + ("max-age", secs) => secs.parse().map(MaxAge).map_err(Some), + ("max-stale", secs) => secs.parse().map(MaxStale).map_err(Some), + ("min-fresh", secs) => secs.parse().map(MinFresh).map_err(Some), + ("s-maxage", secs) => secs.parse().map(SMaxAge).map_err(Some), + (left, right) => { + Ok(Extension(left.to_owned(), Some(right.to_owned()))) + } + } + } + Some(_) => Err(None), + None => Ok(Extension(s.to_owned(), None)), + }, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::header::Header; + use crate::test::TestRequest; + + #[test] + fn test_parse_multiple_headers() { + let req = TestRequest::with_header(header::CACHE_CONTROL, "no-cache, private") + .finish(); + let cache = Header::parse(&req); + assert_eq!( + cache.ok(), + Some(CacheControl(vec![ + CacheDirective::NoCache, + CacheDirective::Private, + ])) + ) + } + + #[test] + fn test_parse_argument() { + let req = + TestRequest::with_header(header::CACHE_CONTROL, "max-age=100, private") + .finish(); + let cache = Header::parse(&req); + assert_eq!( + cache.ok(), + Some(CacheControl(vec![ + CacheDirective::MaxAge(100), + CacheDirective::Private, + ])) + ) + } + + #[test] + fn test_parse_quote_form() { + let req = + TestRequest::with_header(header::CACHE_CONTROL, "max-age=\"200\"").finish(); + let cache = Header::parse(&req); + assert_eq!( + cache.ok(), + Some(CacheControl(vec![CacheDirective::MaxAge(200)])) + ) + } + + #[test] + fn test_parse_extension() { + let req = + TestRequest::with_header(header::CACHE_CONTROL, "foo, bar=baz").finish(); + let cache = Header::parse(&req); + assert_eq!( + cache.ok(), + Some(CacheControl(vec![ + CacheDirective::Extension("foo".to_owned(), None), + CacheDirective::Extension("bar".to_owned(), Some("baz".to_owned())), + ])) + ) + } + + #[test] + fn test_parse_bad_syntax() { + let req = TestRequest::with_header(header::CACHE_CONTROL, "foo=").finish(); + let cache: Result = Header::parse(&req); + assert_eq!(cache.ok(), None) + } +} diff --git a/actix-http/src/header/common/content_disposition.rs b/actix-http/src/header/common/content_disposition.rs new file mode 100644 index 000000000..700400dad --- /dev/null +++ b/actix-http/src/header/common/content_disposition.rs @@ -0,0 +1,918 @@ +// # References +// +// "The Content-Disposition Header Field" https://www.ietf.org/rfc/rfc2183.txt +// "The Content-Disposition Header Field in the Hypertext Transfer Protocol (HTTP)" https://www.ietf.org/rfc/rfc6266.txt +// "Returning Values from Forms: multipart/form-data" https://www.ietf.org/rfc/rfc7578.txt +// Browser conformance tests at: http://greenbytes.de/tech/tc2231/ +// IANA assignment: http://www.iana.org/assignments/cont-disp/cont-disp.xhtml + +use lazy_static::lazy_static; +use regex::Regex; +use std::fmt::{self, Write}; + +use crate::header::{self, ExtendedValue, Header, IntoHeaderValue, Writer}; + +/// Split at the index of the first `needle` if it exists or at the end. +fn split_once(haystack: &str, needle: char) -> (&str, &str) { + haystack.find(needle).map_or_else( + || (haystack, ""), + |sc| { + let (first, last) = haystack.split_at(sc); + (first, last.split_at(1).1) + }, + ) +} + +/// Split at the index of the first `needle` if it exists or at the end, trim the right of the +/// first part and the left of the last part. +fn split_once_and_trim(haystack: &str, needle: char) -> (&str, &str) { + let (first, last) = split_once(haystack, needle); + (first.trim_end(), last.trim_start()) +} + +/// The implied disposition of the content of the HTTP body. +#[derive(Clone, Debug, PartialEq)] +pub enum DispositionType { + /// Inline implies default processing + Inline, + /// Attachment implies that the recipient should prompt the user to save the response locally, + /// rather than process it normally (as per its media type). + Attachment, + /// Used in *multipart/form-data* as defined in + /// [RFC7578](https://tools.ietf.org/html/rfc7578) to carry the field name and the file name. + FormData, + /// Extension type. Should be handled by recipients the same way as Attachment + Ext(String), +} + +impl<'a> From<&'a str> for DispositionType { + fn from(origin: &'a str) -> DispositionType { + if origin.eq_ignore_ascii_case("inline") { + DispositionType::Inline + } else if origin.eq_ignore_ascii_case("attachment") { + DispositionType::Attachment + } else if origin.eq_ignore_ascii_case("form-data") { + DispositionType::FormData + } else { + DispositionType::Ext(origin.to_owned()) + } + } +} + +/// Parameter in [`ContentDisposition`]. +/// +/// # Examples +/// ``` +/// use actix_http::http::header::DispositionParam; +/// +/// let param = DispositionParam::Filename(String::from("sample.txt")); +/// assert!(param.is_filename()); +/// assert_eq!(param.as_filename().unwrap(), "sample.txt"); +/// ``` +#[derive(Clone, Debug, PartialEq)] +pub enum DispositionParam { + /// For [`DispositionType::FormData`] (i.e. *multipart/form-data*), the name of an field from + /// the form. + Name(String), + /// A plain file name. + Filename(String), + /// An extended file name. It must not exist for `ContentType::Formdata` according to + /// [RFC7578 Section 4.2](https://tools.ietf.org/html/rfc7578#section-4.2). + FilenameExt(ExtendedValue), + /// An unrecognized regular parameter as defined in + /// [RFC5987](https://tools.ietf.org/html/rfc5987) as *reg-parameter*, in + /// [RFC6266](https://tools.ietf.org/html/rfc6266) as *token "=" value*. Recipients should + /// ignore unrecognizable parameters. + Unknown(String, String), + /// An unrecognized extended paramater as defined in + /// [RFC5987](https://tools.ietf.org/html/rfc5987) as *ext-parameter*, in + /// [RFC6266](https://tools.ietf.org/html/rfc6266) as *ext-token "=" ext-value*. The single + /// trailling asterisk is not included. Recipients should ignore unrecognizable parameters. + UnknownExt(String, ExtendedValue), +} + +impl DispositionParam { + /// Returns `true` if the paramater is [`Name`](DispositionParam::Name). + #[inline] + pub fn is_name(&self) -> bool { + self.as_name().is_some() + } + + /// Returns `true` if the paramater is [`Filename`](DispositionParam::Filename). + #[inline] + pub fn is_filename(&self) -> bool { + self.as_filename().is_some() + } + + /// Returns `true` if the paramater is [`FilenameExt`](DispositionParam::FilenameExt). + #[inline] + pub fn is_filename_ext(&self) -> bool { + self.as_filename_ext().is_some() + } + + /// Returns `true` if the paramater is [`Unknown`](DispositionParam::Unknown) and the `name` + #[inline] + /// matches. + pub fn is_unknown>(&self, name: T) -> bool { + self.as_unknown(name).is_some() + } + + /// Returns `true` if the paramater is [`UnknownExt`](DispositionParam::UnknownExt) and the + /// `name` matches. + #[inline] + pub fn is_unknown_ext>(&self, name: T) -> bool { + self.as_unknown_ext(name).is_some() + } + + /// Returns the name if applicable. + #[inline] + pub fn as_name(&self) -> Option<&str> { + match self { + DispositionParam::Name(ref name) => Some(name.as_str()), + _ => None, + } + } + + /// Returns the filename if applicable. + #[inline] + pub fn as_filename(&self) -> Option<&str> { + match self { + DispositionParam::Filename(ref filename) => Some(filename.as_str()), + _ => None, + } + } + + /// Returns the filename* if applicable. + #[inline] + pub fn as_filename_ext(&self) -> Option<&ExtendedValue> { + match self { + DispositionParam::FilenameExt(ref value) => Some(value), + _ => None, + } + } + + /// Returns the value of the unrecognized regular parameter if it is + /// [`Unknown`](DispositionParam::Unknown) and the `name` matches. + #[inline] + pub fn as_unknown>(&self, name: T) -> Option<&str> { + match self { + DispositionParam::Unknown(ref ext_name, ref value) + if ext_name.eq_ignore_ascii_case(name.as_ref()) => + { + Some(value.as_str()) + } + _ => None, + } + } + + /// Returns the value of the unrecognized extended parameter if it is + /// [`Unknown`](DispositionParam::Unknown) and the `name` matches. + #[inline] + pub fn as_unknown_ext>(&self, name: T) -> Option<&ExtendedValue> { + match self { + DispositionParam::UnknownExt(ref ext_name, ref value) + if ext_name.eq_ignore_ascii_case(name.as_ref()) => + { + Some(value) + } + _ => None, + } + } +} + +/// A *Content-Disposition* header. It is compatible to be used either as +/// [a response header for the main body](https://mdn.io/Content-Disposition#As_a_response_header_for_the_main_body) +/// as (re)defined in [RFC6266](https://tools.ietf.org/html/rfc6266), or as +/// [a header for a multipart body](https://mdn.io/Content-Disposition#As_a_header_for_a_multipart_body) +/// as (re)defined in [RFC7587](https://tools.ietf.org/html/rfc7578). +/// +/// In a regular HTTP response, the *Content-Disposition* response header is a header indicating if +/// the content is expected to be displayed *inline* in the browser, that is, as a Web page or as +/// part of a Web page, or as an attachment, that is downloaded and saved locally, and also can be +/// used to attach additional metadata, such as the filename to use when saving the response payload +/// locally. +/// +/// In a *multipart/form-data* body, the HTTP *Content-Disposition* general header is a header that +/// can be used on the subpart of a multipart body to give information about the field it applies to. +/// The subpart is delimited by the boundary defined in the *Content-Type* header. Used on the body +/// itself, *Content-Disposition* has no effect. +/// +/// # ABNF + +/// ```text +/// content-disposition = "Content-Disposition" ":" +/// disposition-type *( ";" disposition-parm ) +/// +/// disposition-type = "inline" | "attachment" | disp-ext-type +/// ; case-insensitive +/// +/// disp-ext-type = token +/// +/// disposition-parm = filename-parm | disp-ext-parm +/// +/// filename-parm = "filename" "=" value +/// | "filename*" "=" ext-value +/// +/// disp-ext-parm = token "=" value +/// | ext-token "=" ext-value +/// +/// ext-token = +/// ``` +/// +/// **Note**: filename* [must not](https://tools.ietf.org/html/rfc7578#section-4.2) be used within +/// *multipart/form-data*. +/// +/// # Example +/// +/// ``` +/// use actix_http::http::header::{ +/// Charset, ContentDisposition, DispositionParam, DispositionType, +/// ExtendedValue, +/// }; +/// +/// let cd1 = ContentDisposition { +/// disposition: DispositionType::Attachment, +/// parameters: vec![DispositionParam::FilenameExt(ExtendedValue { +/// charset: Charset::Iso_8859_1, // The character set for the bytes of the filename +/// language_tag: None, // The optional language tag (see `language-tag` crate) +/// value: b"\xa9 Copyright 1989.txt".to_vec(), // the actual bytes of the filename +/// })], +/// }; +/// assert!(cd1.is_attachment()); +/// assert!(cd1.get_filename_ext().is_some()); +/// +/// let cd2 = ContentDisposition { +/// disposition: DispositionType::FormData, +/// parameters: vec![ +/// DispositionParam::Name(String::from("file")), +/// DispositionParam::Filename(String::from("bill.odt")), +/// ], +/// }; +/// assert_eq!(cd2.get_name(), Some("file")); // field name +/// assert_eq!(cd2.get_filename(), Some("bill.odt")); +/// ``` +/// +/// # WARN +/// If "filename" parameter is supplied, do not use the file name blindly, check and possibly +/// change to match local file system conventions if applicable, and do not use directory path +/// information that may be present. See [RFC2183](https://tools.ietf.org/html/rfc2183#section-2.3) +/// . +#[derive(Clone, Debug, PartialEq)] +pub struct ContentDisposition { + /// The disposition type + pub disposition: DispositionType, + /// Disposition parameters + pub parameters: Vec, +} + +impl ContentDisposition { + /// Parse a raw Content-Disposition header value. + pub fn from_raw(hv: &header::HeaderValue) -> Result { + // `header::from_one_raw_str` invokes `hv.to_str` which assumes `hv` contains only visible + // ASCII characters. So `hv.as_bytes` is necessary here. + let hv = String::from_utf8(hv.as_bytes().to_vec()) + .map_err(|_| crate::error::ParseError::Header)?; + let (disp_type, mut left) = split_once_and_trim(hv.as_str().trim(), ';'); + if disp_type.is_empty() { + return Err(crate::error::ParseError::Header); + } + let mut cd = ContentDisposition { + disposition: disp_type.into(), + parameters: Vec::new(), + }; + + while !left.is_empty() { + let (param_name, new_left) = split_once_and_trim(left, '='); + if param_name.is_empty() || param_name == "*" || new_left.is_empty() { + return Err(crate::error::ParseError::Header); + } + left = new_left; + if param_name.ends_with('*') { + // extended parameters + let param_name = ¶m_name[..param_name.len() - 1]; // trim asterisk + let (ext_value, new_left) = split_once_and_trim(left, ';'); + left = new_left; + let ext_value = header::parse_extended_value(ext_value)?; + + let param = if param_name.eq_ignore_ascii_case("filename") { + DispositionParam::FilenameExt(ext_value) + } else { + DispositionParam::UnknownExt(param_name.to_owned(), ext_value) + }; + cd.parameters.push(param); + } else { + // regular parameters + let value = if left.starts_with('\"') { + // quoted-string: defined in RFC6266 -> RFC2616 Section 3.6 + let mut escaping = false; + let mut quoted_string = vec![]; + let mut end = None; + // search for closing quote + for (i, &c) in left.as_bytes().iter().skip(1).enumerate() { + if escaping { + escaping = false; + quoted_string.push(c); + } else if c == 0x5c { + // backslash + escaping = true; + } else if c == 0x22 { + // double quote + end = Some(i + 1); // cuz skipped 1 for the leading quote + break; + } else { + quoted_string.push(c); + } + } + left = &left[end.ok_or(crate::error::ParseError::Header)? + 1..]; + left = split_once(left, ';').1.trim_start(); + // In fact, it should not be Err if the above code is correct. + String::from_utf8(quoted_string) + .map_err(|_| crate::error::ParseError::Header)? + } else { + // token: won't contains semicolon according to RFC 2616 Section 2.2 + let (token, new_left) = split_once_and_trim(left, ';'); + left = new_left; + token.to_owned() + }; + if value.is_empty() { + return Err(crate::error::ParseError::Header); + } + + let param = if param_name.eq_ignore_ascii_case("name") { + DispositionParam::Name(value) + } else if param_name.eq_ignore_ascii_case("filename") { + DispositionParam::Filename(value) + } else { + DispositionParam::Unknown(param_name.to_owned(), value) + }; + cd.parameters.push(param); + } + } + + Ok(cd) + } + + /// Returns `true` if it is [`Inline`](DispositionType::Inline). + pub fn is_inline(&self) -> bool { + match self.disposition { + DispositionType::Inline => true, + _ => false, + } + } + + /// Returns `true` if it is [`Attachment`](DispositionType::Attachment). + pub fn is_attachment(&self) -> bool { + match self.disposition { + DispositionType::Attachment => true, + _ => false, + } + } + + /// Returns `true` if it is [`FormData`](DispositionType::FormData). + pub fn is_form_data(&self) -> bool { + match self.disposition { + DispositionType::FormData => true, + _ => false, + } + } + + /// Returns `true` if it is [`Ext`](DispositionType::Ext) and the `disp_type` matches. + pub fn is_ext>(&self, disp_type: T) -> bool { + match self.disposition { + DispositionType::Ext(ref t) + if t.eq_ignore_ascii_case(disp_type.as_ref()) => + { + true + } + _ => false, + } + } + + /// Return the value of *name* if exists. + pub fn get_name(&self) -> Option<&str> { + self.parameters.iter().filter_map(|p| p.as_name()).nth(0) + } + + /// Return the value of *filename* if exists. + pub fn get_filename(&self) -> Option<&str> { + self.parameters + .iter() + .filter_map(|p| p.as_filename()) + .nth(0) + } + + /// Return the value of *filename\** if exists. + pub fn get_filename_ext(&self) -> Option<&ExtendedValue> { + self.parameters + .iter() + .filter_map(|p| p.as_filename_ext()) + .nth(0) + } + + /// Return the value of the parameter which the `name` matches. + pub fn get_unknown>(&self, name: T) -> Option<&str> { + let name = name.as_ref(); + self.parameters + .iter() + .filter_map(|p| p.as_unknown(name)) + .nth(0) + } + + /// Return the value of the extended parameter which the `name` matches. + pub fn get_unknown_ext>(&self, name: T) -> Option<&ExtendedValue> { + let name = name.as_ref(); + self.parameters + .iter() + .filter_map(|p| p.as_unknown_ext(name)) + .nth(0) + } +} + +impl IntoHeaderValue for ContentDisposition { + type Error = header::InvalidHeaderValueBytes; + + fn try_into(self) -> Result { + let mut writer = Writer::new(); + let _ = write!(&mut writer, "{}", self); + header::HeaderValue::from_shared(writer.take()) + } +} + +impl Header for ContentDisposition { + fn name() -> header::HeaderName { + header::CONTENT_DISPOSITION + } + + fn parse(msg: &T) -> Result { + if let Some(h) = msg.headers().get(Self::name()) { + Self::from_raw(&h) + } else { + Err(crate::error::ParseError::Header) + } + } +} + +impl fmt::Display for DispositionType { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + DispositionType::Inline => write!(f, "inline"), + DispositionType::Attachment => write!(f, "attachment"), + DispositionType::FormData => write!(f, "form-data"), + DispositionType::Ext(ref s) => write!(f, "{}", s), + } + } +} + +impl fmt::Display for DispositionParam { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + // All ASCII control charaters (0-30, 127) excepting horizontal tab, double quote, and + // backslash should be escaped in quoted-string (i.e. "foobar"). + // Ref: RFC6266 S4.1 -> RFC2616 S2.2; RFC 7578 S4.2 -> RFC2183 S2 -> ... . + lazy_static! { + static ref RE: Regex = Regex::new("[\x01-\x08\x10\x1F\x7F\"\\\\]").unwrap(); + } + match self { + DispositionParam::Name(ref value) => write!(f, "name={}", value), + DispositionParam::Filename(ref value) => { + write!(f, "filename=\"{}\"", RE.replace_all(value, "\\$0").as_ref()) + } + DispositionParam::Unknown(ref name, ref value) => write!( + f, + "{}=\"{}\"", + name, + &RE.replace_all(value, "\\$0").as_ref() + ), + DispositionParam::FilenameExt(ref ext_value) => { + write!(f, "filename*={}", ext_value) + } + DispositionParam::UnknownExt(ref name, ref ext_value) => { + write!(f, "{}*={}", name, ext_value) + } + } + } +} + +impl fmt::Display for ContentDisposition { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.disposition)?; + self.parameters + .iter() + .map(|param| write!(f, "; {}", param)) + .collect() + } +} + +#[cfg(test)] +mod tests { + use super::{ContentDisposition, DispositionParam, DispositionType}; + use crate::header::shared::Charset; + use crate::header::{ExtendedValue, HeaderValue}; + + #[test] + fn test_from_raw_basic() { + assert!(ContentDisposition::from_raw(&HeaderValue::from_static("")).is_err()); + + let a = HeaderValue::from_static( + "form-data; dummy=3; name=upload; filename=\"sample.png\"", + ); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let b = ContentDisposition { + disposition: DispositionType::FormData, + parameters: vec![ + DispositionParam::Unknown("dummy".to_owned(), "3".to_owned()), + DispositionParam::Name("upload".to_owned()), + DispositionParam::Filename("sample.png".to_owned()), + ], + }; + assert_eq!(a, b); + + let a = HeaderValue::from_static("attachment; filename=\"image.jpg\""); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let b = ContentDisposition { + disposition: DispositionType::Attachment, + parameters: vec![DispositionParam::Filename("image.jpg".to_owned())], + }; + assert_eq!(a, b); + + let a = HeaderValue::from_static("inline; filename=image.jpg"); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let b = ContentDisposition { + disposition: DispositionType::Inline, + parameters: vec![DispositionParam::Filename("image.jpg".to_owned())], + }; + assert_eq!(a, b); + + let a = HeaderValue::from_static( + "attachment; creation-date=\"Wed, 12 Feb 1997 16:29:51 -0500\"", + ); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let b = ContentDisposition { + disposition: DispositionType::Attachment, + parameters: vec![DispositionParam::Unknown( + String::from("creation-date"), + "Wed, 12 Feb 1997 16:29:51 -0500".to_owned(), + )], + }; + assert_eq!(a, b); + } + + #[test] + fn test_from_raw_extended() { + let a = HeaderValue::from_static( + "attachment; filename*=UTF-8''%c2%a3%20and%20%e2%82%ac%20rates", + ); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let b = ContentDisposition { + disposition: DispositionType::Attachment, + parameters: vec![DispositionParam::FilenameExt(ExtendedValue { + charset: Charset::Ext(String::from("UTF-8")), + language_tag: None, + value: vec![ + 0xc2, 0xa3, 0x20, b'a', b'n', b'd', 0x20, 0xe2, 0x82, 0xac, 0x20, + b'r', b'a', b't', b'e', b's', + ], + })], + }; + assert_eq!(a, b); + + let a = HeaderValue::from_static( + "attachment; filename*=UTF-8''%c2%a3%20and%20%e2%82%ac%20rates", + ); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let b = ContentDisposition { + disposition: DispositionType::Attachment, + parameters: vec![DispositionParam::FilenameExt(ExtendedValue { + charset: Charset::Ext(String::from("UTF-8")), + language_tag: None, + value: vec![ + 0xc2, 0xa3, 0x20, b'a', b'n', b'd', 0x20, 0xe2, 0x82, 0xac, 0x20, + b'r', b'a', b't', b'e', b's', + ], + })], + }; + assert_eq!(a, b); + } + + #[test] + fn test_from_raw_extra_whitespace() { + let a = HeaderValue::from_static( + "form-data ; du-mmy= 3 ; name =upload ; filename = \"sample.png\" ; ", + ); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let b = ContentDisposition { + disposition: DispositionType::FormData, + parameters: vec![ + DispositionParam::Unknown("du-mmy".to_owned(), "3".to_owned()), + DispositionParam::Name("upload".to_owned()), + DispositionParam::Filename("sample.png".to_owned()), + ], + }; + assert_eq!(a, b); + } + + #[test] + fn test_from_raw_unordered() { + let a = HeaderValue::from_static( + "form-data; dummy=3; filename=\"sample.png\" ; name=upload;", + // Actually, a trailling semolocon is not compliant. But it is fine to accept. + ); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let b = ContentDisposition { + disposition: DispositionType::FormData, + parameters: vec![ + DispositionParam::Unknown("dummy".to_owned(), "3".to_owned()), + DispositionParam::Filename("sample.png".to_owned()), + DispositionParam::Name("upload".to_owned()), + ], + }; + assert_eq!(a, b); + + let a = HeaderValue::from_str( + "attachment; filename*=iso-8859-1''foo-%E4.html; filename=\"foo-ä.html\"", + ) + .unwrap(); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let b = ContentDisposition { + disposition: DispositionType::Attachment, + parameters: vec![ + DispositionParam::FilenameExt(ExtendedValue { + charset: Charset::Iso_8859_1, + language_tag: None, + value: b"foo-\xe4.html".to_vec(), + }), + DispositionParam::Filename("foo-ä.html".to_owned()), + ], + }; + assert_eq!(a, b); + } + + #[test] + fn test_from_raw_only_disp() { + let a = ContentDisposition::from_raw(&HeaderValue::from_static("attachment")) + .unwrap(); + let b = ContentDisposition { + disposition: DispositionType::Attachment, + parameters: vec![], + }; + assert_eq!(a, b); + + let a = + ContentDisposition::from_raw(&HeaderValue::from_static("inline ;")).unwrap(); + let b = ContentDisposition { + disposition: DispositionType::Inline, + parameters: vec![], + }; + assert_eq!(a, b); + + let a = ContentDisposition::from_raw(&HeaderValue::from_static( + "unknown-disp-param", + )) + .unwrap(); + let b = ContentDisposition { + disposition: DispositionType::Ext(String::from("unknown-disp-param")), + parameters: vec![], + }; + assert_eq!(a, b); + } + + #[test] + fn from_raw_with_mixed_case() { + let a = HeaderValue::from_str( + "InLInE; fIlenAME*=iso-8859-1''foo-%E4.html; filEName=\"foo-ä.html\"", + ) + .unwrap(); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let b = ContentDisposition { + disposition: DispositionType::Inline, + parameters: vec![ + DispositionParam::FilenameExt(ExtendedValue { + charset: Charset::Iso_8859_1, + language_tag: None, + value: b"foo-\xe4.html".to_vec(), + }), + DispositionParam::Filename("foo-ä.html".to_owned()), + ], + }; + assert_eq!(a, b); + } + + #[test] + fn from_raw_with_unicode() { + /* RFC7578 Section 4.2: + Some commonly deployed systems use multipart/form-data with file names directly encoded + including octets outside the US-ASCII range. The encoding used for the file names is + typically UTF-8, although HTML forms will use the charset associated with the form. + + Mainstream browsers like Firefox (gecko) and Chrome use UTF-8 directly as above. + (And now, only UTF-8 is handled by this implementation.) + */ + let a = + HeaderValue::from_str("form-data; name=upload; filename=\"文件.webp\"") + .unwrap(); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let b = ContentDisposition { + disposition: DispositionType::FormData, + parameters: vec![ + DispositionParam::Name(String::from("upload")), + DispositionParam::Filename(String::from("文件.webp")), + ], + }; + assert_eq!(a, b); + + let a = + HeaderValue::from_str("form-data; name=upload; filename=\"余固知謇謇之為患兮,忍而不能舍也.pptx\"").unwrap(); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let b = ContentDisposition { + disposition: DispositionType::FormData, + parameters: vec![ + DispositionParam::Name(String::from("upload")), + DispositionParam::Filename(String::from( + "余固知謇謇之為患兮,忍而不能舍也.pptx", + )), + ], + }; + assert_eq!(a, b); + } + + #[test] + fn test_from_raw_escape() { + let a = HeaderValue::from_static( + "form-data; dummy=3; name=upload; filename=\"s\\amp\\\"le.png\"", + ); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let b = ContentDisposition { + disposition: DispositionType::FormData, + parameters: vec![ + DispositionParam::Unknown("dummy".to_owned(), "3".to_owned()), + DispositionParam::Name("upload".to_owned()), + DispositionParam::Filename( + ['s', 'a', 'm', 'p', '\"', 'l', 'e', '.', 'p', 'n', 'g'] + .iter() + .collect(), + ), + ], + }; + assert_eq!(a, b); + } + + #[test] + fn test_from_raw_semicolon() { + let a = + HeaderValue::from_static("form-data; filename=\"A semicolon here;.pdf\""); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let b = ContentDisposition { + disposition: DispositionType::FormData, + parameters: vec![DispositionParam::Filename(String::from( + "A semicolon here;.pdf", + ))], + }; + assert_eq!(a, b); + } + + #[test] + fn test_from_raw_uncessary_percent_decode() { + let a = HeaderValue::from_static( + "form-data; name=photo; filename=\"%74%65%73%74%2e%70%6e%67\"", // Should not be decoded! + ); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let b = ContentDisposition { + disposition: DispositionType::FormData, + parameters: vec![ + DispositionParam::Name("photo".to_owned()), + DispositionParam::Filename(String::from("%74%65%73%74%2e%70%6e%67")), + ], + }; + assert_eq!(a, b); + + let a = HeaderValue::from_static( + "form-data; name=photo; filename=\"%74%65%73%74.png\"", + ); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let b = ContentDisposition { + disposition: DispositionType::FormData, + parameters: vec![ + DispositionParam::Name("photo".to_owned()), + DispositionParam::Filename(String::from("%74%65%73%74.png")), + ], + }; + assert_eq!(a, b); + } + + #[test] + fn test_from_raw_param_value_missing() { + let a = HeaderValue::from_static("form-data; name=upload ; filename="); + assert!(ContentDisposition::from_raw(&a).is_err()); + + let a = HeaderValue::from_static("attachment; dummy=; filename=invoice.pdf"); + assert!(ContentDisposition::from_raw(&a).is_err()); + + let a = HeaderValue::from_static("inline; filename= "); + assert!(ContentDisposition::from_raw(&a).is_err()); + } + + #[test] + fn test_from_raw_param_name_missing() { + let a = HeaderValue::from_static("inline; =\"test.txt\""); + assert!(ContentDisposition::from_raw(&a).is_err()); + + let a = HeaderValue::from_static("inline; =diary.odt"); + assert!(ContentDisposition::from_raw(&a).is_err()); + + let a = HeaderValue::from_static("inline; ="); + assert!(ContentDisposition::from_raw(&a).is_err()); + } + + #[test] + fn test_display_extended() { + let as_string = + "attachment; filename*=UTF-8'en'%C2%A3%20and%20%E2%82%AC%20rates"; + let a = HeaderValue::from_static(as_string); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let display_rendered = format!("{}", a); + assert_eq!(as_string, display_rendered); + + let a = HeaderValue::from_static("attachment; filename=colourful.csv"); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let display_rendered = format!("{}", a); + assert_eq!( + "attachment; filename=\"colourful.csv\"".to_owned(), + display_rendered + ); + } + + #[test] + fn test_display_quote() { + let as_string = "form-data; name=upload; filename=\"Quote\\\"here.png\""; + as_string + .find(['\\', '\"'].iter().collect::().as_str()) + .unwrap(); // ensure `\"` is there + let a = HeaderValue::from_static(as_string); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let display_rendered = format!("{}", a); + assert_eq!(as_string, display_rendered); + } + + #[test] + fn test_display_space_tab() { + let as_string = "form-data; name=upload; filename=\"Space here.png\""; + let a = HeaderValue::from_static(as_string); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let display_rendered = format!("{}", a); + assert_eq!(as_string, display_rendered); + + let a: ContentDisposition = ContentDisposition { + disposition: DispositionType::Inline, + parameters: vec![DispositionParam::Filename(String::from("Tab\there.png"))], + }; + let display_rendered = format!("{}", a); + assert_eq!("inline; filename=\"Tab\x09here.png\"", display_rendered); + } + + #[test] + fn test_display_control_characters() { + /* let a = "attachment; filename=\"carriage\rreturn.png\""; + let a = HeaderValue::from_static(a); + let a: ContentDisposition = ContentDisposition::from_raw(&a).unwrap(); + let display_rendered = format!("{}", a); + assert_eq!( + "attachment; filename=\"carriage\\\rreturn.png\"", + display_rendered + );*/ + // No way to create a HeaderValue containing a carriage return. + + let a: ContentDisposition = ContentDisposition { + disposition: DispositionType::Inline, + parameters: vec![DispositionParam::Filename(String::from("bell\x07.png"))], + }; + let display_rendered = format!("{}", a); + assert_eq!("inline; filename=\"bell\\\x07.png\"", display_rendered); + } + + #[test] + fn test_param_methods() { + let param = DispositionParam::Filename(String::from("sample.txt")); + assert!(param.is_filename()); + assert_eq!(param.as_filename().unwrap(), "sample.txt"); + + let param = DispositionParam::Unknown(String::from("foo"), String::from("bar")); + assert!(param.is_unknown("foo")); + assert_eq!(param.as_unknown("fOo"), Some("bar")); + } + + #[test] + fn test_disposition_methods() { + let cd = ContentDisposition { + disposition: DispositionType::FormData, + parameters: vec![ + DispositionParam::Unknown("dummy".to_owned(), "3".to_owned()), + DispositionParam::Name("upload".to_owned()), + DispositionParam::Filename("sample.png".to_owned()), + ], + }; + assert_eq!(cd.get_name(), Some("upload")); + assert_eq!(cd.get_unknown("dummy"), Some("3")); + assert_eq!(cd.get_filename(), Some("sample.png")); + assert_eq!(cd.get_unknown_ext("dummy"), None); + assert_eq!(cd.get_unknown("duMMy"), Some("3")); + } +} diff --git a/actix-http/src/header/common/content_language.rs b/actix-http/src/header/common/content_language.rs new file mode 100644 index 000000000..838981a39 --- /dev/null +++ b/actix-http/src/header/common/content_language.rs @@ -0,0 +1,65 @@ +use crate::header::{QualityItem, CONTENT_LANGUAGE}; +use language_tags::LanguageTag; + +header! { + /// `Content-Language` header, defined in + /// [RFC7231](https://tools.ietf.org/html/rfc7231#section-3.1.3.2) + /// + /// The `Content-Language` header field describes the natural language(s) + /// of the intended audience for the representation. Note that this + /// might not be equivalent to all the languages used within the + /// representation. + /// + /// # ABNF + /// + /// ```text + /// Content-Language = 1#language-tag + /// ``` + /// + /// # Example values + /// + /// * `da` + /// * `mi, en` + /// + /// # Examples + /// + /// ```rust + /// # extern crate actix_http; + /// # #[macro_use] extern crate language_tags; + /// use actix_http::Response; + /// # use actix_http::http::header::{ContentLanguage, qitem}; + /// # + /// # fn main() { + /// let mut builder = Response::Ok(); + /// builder.set( + /// ContentLanguage(vec![ + /// qitem(langtag!(en)), + /// ]) + /// ); + /// # } + /// ``` + /// + /// ```rust + /// # extern crate actix_http; + /// # #[macro_use] extern crate language_tags; + /// use actix_http::Response; + /// # use actix_http::http::header::{ContentLanguage, qitem}; + /// # + /// # fn main() { + /// + /// let mut builder = Response::Ok(); + /// builder.set( + /// ContentLanguage(vec![ + /// qitem(langtag!(da)), + /// qitem(langtag!(en;;;GB)), + /// ]) + /// ); + /// # } + /// ``` + (ContentLanguage, CONTENT_LANGUAGE) => (QualityItem)+ + + test_content_language { + test_header!(test1, vec![b"da"]); + test_header!(test2, vec![b"mi, en"]); + } +} diff --git a/actix-http/src/header/common/content_range.rs b/actix-http/src/header/common/content_range.rs new file mode 100644 index 000000000..cc7f27548 --- /dev/null +++ b/actix-http/src/header/common/content_range.rs @@ -0,0 +1,208 @@ +use std::fmt::{self, Display, Write}; +use std::str::FromStr; + +use crate::error::ParseError; +use crate::header::{ + HeaderValue, IntoHeaderValue, InvalidHeaderValueBytes, Writer, CONTENT_RANGE, +}; + +header! { + /// `Content-Range` header, defined in + /// [RFC7233](http://tools.ietf.org/html/rfc7233#section-4.2) + (ContentRange, CONTENT_RANGE) => [ContentRangeSpec] + + test_content_range { + test_header!(test_bytes, + vec![b"bytes 0-499/500"], + Some(ContentRange(ContentRangeSpec::Bytes { + range: Some((0, 499)), + instance_length: Some(500) + }))); + + test_header!(test_bytes_unknown_len, + vec![b"bytes 0-499/*"], + Some(ContentRange(ContentRangeSpec::Bytes { + range: Some((0, 499)), + instance_length: None + }))); + + test_header!(test_bytes_unknown_range, + vec![b"bytes */500"], + Some(ContentRange(ContentRangeSpec::Bytes { + range: None, + instance_length: Some(500) + }))); + + test_header!(test_unregistered, + vec![b"seconds 1-2"], + Some(ContentRange(ContentRangeSpec::Unregistered { + unit: "seconds".to_owned(), + resp: "1-2".to_owned() + }))); + + test_header!(test_no_len, + vec![b"bytes 0-499"], + None::); + + test_header!(test_only_unit, + vec![b"bytes"], + None::); + + test_header!(test_end_less_than_start, + vec![b"bytes 499-0/500"], + None::); + + test_header!(test_blank, + vec![b""], + None::); + + test_header!(test_bytes_many_spaces, + vec![b"bytes 1-2/500 3"], + None::); + + test_header!(test_bytes_many_slashes, + vec![b"bytes 1-2/500/600"], + None::); + + test_header!(test_bytes_many_dashes, + vec![b"bytes 1-2-3/500"], + None::); + + } +} + +/// Content-Range, described in [RFC7233](https://tools.ietf.org/html/rfc7233#section-4.2) +/// +/// # ABNF +/// +/// ```text +/// Content-Range = byte-content-range +/// / other-content-range +/// +/// byte-content-range = bytes-unit SP +/// ( byte-range-resp / unsatisfied-range ) +/// +/// byte-range-resp = byte-range "/" ( complete-length / "*" ) +/// byte-range = first-byte-pos "-" last-byte-pos +/// unsatisfied-range = "*/" complete-length +/// +/// complete-length = 1*DIGIT +/// +/// other-content-range = other-range-unit SP other-range-resp +/// other-range-resp = *CHAR +/// ``` +#[derive(PartialEq, Clone, Debug)] +pub enum ContentRangeSpec { + /// Byte range + Bytes { + /// First and last bytes of the range, omitted if request could not be + /// satisfied + range: Option<(u64, u64)>, + + /// Total length of the instance, can be omitted if unknown + instance_length: Option, + }, + + /// Custom range, with unit not registered at IANA + Unregistered { + /// other-range-unit + unit: String, + + /// other-range-resp + resp: String, + }, +} + +fn split_in_two(s: &str, separator: char) -> Option<(&str, &str)> { + let mut iter = s.splitn(2, separator); + match (iter.next(), iter.next()) { + (Some(a), Some(b)) => Some((a, b)), + _ => None, + } +} + +impl FromStr for ContentRangeSpec { + type Err = ParseError; + + fn from_str(s: &str) -> Result { + let res = match split_in_two(s, ' ') { + Some(("bytes", resp)) => { + let (range, instance_length) = + split_in_two(resp, '/').ok_or(ParseError::Header)?; + + let instance_length = if instance_length == "*" { + None + } else { + Some(instance_length.parse().map_err(|_| ParseError::Header)?) + }; + + let range = if range == "*" { + None + } else { + let (first_byte, last_byte) = + split_in_two(range, '-').ok_or(ParseError::Header)?; + let first_byte = + first_byte.parse().map_err(|_| ParseError::Header)?; + let last_byte = last_byte.parse().map_err(|_| ParseError::Header)?; + if last_byte < first_byte { + return Err(ParseError::Header); + } + Some((first_byte, last_byte)) + }; + + ContentRangeSpec::Bytes { + range, + instance_length, + } + } + Some((unit, resp)) => ContentRangeSpec::Unregistered { + unit: unit.to_owned(), + resp: resp.to_owned(), + }, + _ => return Err(ParseError::Header), + }; + Ok(res) + } +} + +impl Display for ContentRangeSpec { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + ContentRangeSpec::Bytes { + range, + instance_length, + } => { + f.write_str("bytes ")?; + match range { + Some((first_byte, last_byte)) => { + write!(f, "{}-{}", first_byte, last_byte)?; + } + None => { + f.write_str("*")?; + } + }; + f.write_str("/")?; + if let Some(v) = instance_length { + write!(f, "{}", v) + } else { + f.write_str("*") + } + } + ContentRangeSpec::Unregistered { ref unit, ref resp } => { + f.write_str(unit)?; + f.write_str(" ")?; + f.write_str(resp) + } + } + } +} + +impl IntoHeaderValue for ContentRangeSpec { + type Error = InvalidHeaderValueBytes; + + fn try_into(self) -> Result { + let mut writer = Writer::new(); + let _ = write!(&mut writer, "{}", self); + HeaderValue::from_shared(writer.take()) + } +} diff --git a/actix-http/src/header/common/content_type.rs b/actix-http/src/header/common/content_type.rs new file mode 100644 index 000000000..a0baa5637 --- /dev/null +++ b/actix-http/src/header/common/content_type.rs @@ -0,0 +1,122 @@ +use crate::header::CONTENT_TYPE; +use mime::Mime; + +header! { + /// `Content-Type` header, defined in + /// [RFC7231](http://tools.ietf.org/html/rfc7231#section-3.1.1.5) + /// + /// The `Content-Type` header field indicates the media type of the + /// associated representation: either the representation enclosed in the + /// message payload or the selected representation, as determined by the + /// message semantics. The indicated media type defines both the data + /// format and how that data is intended to be processed by a recipient, + /// within the scope of the received message semantics, after any content + /// codings indicated by Content-Encoding are decoded. + /// + /// Although the `mime` crate allows the mime options to be any slice, this crate + /// forces the use of Vec. This is to make sure the same header can't have more than 1 type. If + /// this is an issue, it's possible to implement `Header` on a custom struct. + /// + /// # ABNF + /// + /// ```text + /// Content-Type = media-type + /// ``` + /// + /// # Example values + /// + /// * `text/html; charset=utf-8` + /// * `application/json` + /// + /// # Examples + /// + /// ```rust + /// use actix_http::Response; + /// use actix_http::http::header::ContentType; + /// + /// # fn main() { + /// let mut builder = Response::Ok(); + /// builder.set( + /// ContentType::json() + /// ); + /// # } + /// ``` + /// + /// ```rust + /// # extern crate mime; + /// # extern crate actix_http; + /// use mime::TEXT_HTML; + /// use actix_http::Response; + /// use actix_http::http::header::ContentType; + /// + /// # fn main() { + /// let mut builder = Response::Ok(); + /// builder.set( + /// ContentType(TEXT_HTML) + /// ); + /// # } + /// ``` + (ContentType, CONTENT_TYPE) => [Mime] + + test_content_type { + test_header!( + test1, + vec![b"text/html"], + Some(HeaderField(mime::TEXT_HTML))); + } +} + +impl ContentType { + /// A constructor to easily create a `Content-Type: application/json` + /// header. + #[inline] + pub fn json() -> ContentType { + ContentType(mime::APPLICATION_JSON) + } + + /// A constructor to easily create a `Content-Type: text/plain; + /// charset=utf-8` header. + #[inline] + pub fn plaintext() -> ContentType { + ContentType(mime::TEXT_PLAIN_UTF_8) + } + + /// A constructor to easily create a `Content-Type: text/html` header. + #[inline] + pub fn html() -> ContentType { + ContentType(mime::TEXT_HTML) + } + + /// A constructor to easily create a `Content-Type: text/xml` header. + #[inline] + pub fn xml() -> ContentType { + ContentType(mime::TEXT_XML) + } + + /// A constructor to easily create a `Content-Type: + /// application/www-form-url-encoded` header. + #[inline] + pub fn form_url_encoded() -> ContentType { + ContentType(mime::APPLICATION_WWW_FORM_URLENCODED) + } + /// A constructor to easily create a `Content-Type: image/jpeg` header. + #[inline] + pub fn jpeg() -> ContentType { + ContentType(mime::IMAGE_JPEG) + } + + /// A constructor to easily create a `Content-Type: image/png` header. + #[inline] + pub fn png() -> ContentType { + ContentType(mime::IMAGE_PNG) + } + + /// A constructor to easily create a `Content-Type: + /// application/octet-stream` header. + #[inline] + pub fn octet_stream() -> ContentType { + ContentType(mime::APPLICATION_OCTET_STREAM) + } +} + +impl Eq for ContentType {} diff --git a/actix-http/src/header/common/date.rs b/actix-http/src/header/common/date.rs new file mode 100644 index 000000000..784100e8d --- /dev/null +++ b/actix-http/src/header/common/date.rs @@ -0,0 +1,42 @@ +use crate::header::{HttpDate, DATE}; +use std::time::SystemTime; + +header! { + /// `Date` header, defined in [RFC7231](http://tools.ietf.org/html/rfc7231#section-7.1.1.2) + /// + /// The `Date` header field represents the date and time at which the + /// message was originated. + /// + /// # ABNF + /// + /// ```text + /// Date = HTTP-date + /// ``` + /// + /// # Example values + /// + /// * `Tue, 15 Nov 1994 08:12:31 GMT` + /// + /// # Example + /// + /// ```rust + /// use actix_http::Response; + /// use actix_http::http::header::Date; + /// use std::time::SystemTime; + /// + /// let mut builder = Response::Ok(); + /// builder.set(Date(SystemTime::now().into())); + /// ``` + (Date, DATE) => [HttpDate] + + test_date { + test_header!(test1, vec![b"Tue, 15 Nov 1994 08:12:31 GMT"]); + } +} + +impl Date { + /// Create a date instance set to the current system time + pub fn now() -> Date { + Date(SystemTime::now().into()) + } +} diff --git a/actix-http/src/header/common/etag.rs b/actix-http/src/header/common/etag.rs new file mode 100644 index 000000000..325b91cbf --- /dev/null +++ b/actix-http/src/header/common/etag.rs @@ -0,0 +1,96 @@ +use crate::header::{EntityTag, ETAG}; + +header! { + /// `ETag` header, defined in [RFC7232](http://tools.ietf.org/html/rfc7232#section-2.3) + /// + /// The `ETag` header field in a response provides the current entity-tag + /// for the selected representation, as determined at the conclusion of + /// handling the request. An entity-tag is an opaque validator for + /// differentiating between multiple representations of the same + /// resource, regardless of whether those multiple representations are + /// due to resource state changes over time, content negotiation + /// resulting in multiple representations being valid at the same time, + /// or both. An entity-tag consists of an opaque quoted string, possibly + /// prefixed by a weakness indicator. + /// + /// # ABNF + /// + /// ```text + /// ETag = entity-tag + /// ``` + /// + /// # Example values + /// + /// * `"xyzzy"` + /// * `W/"xyzzy"` + /// * `""` + /// + /// # Examples + /// + /// ```rust + /// use actix_http::Response; + /// use actix_http::http::header::{ETag, EntityTag}; + /// + /// let mut builder = Response::Ok(); + /// builder.set(ETag(EntityTag::new(false, "xyzzy".to_owned()))); + /// ``` + /// + /// ```rust + /// use actix_http::Response; + /// use actix_http::http::header::{ETag, EntityTag}; + /// + /// let mut builder = Response::Ok(); + /// builder.set(ETag(EntityTag::new(true, "xyzzy".to_owned()))); + /// ``` + (ETag, ETAG) => [EntityTag] + + test_etag { + // From the RFC + test_header!(test1, + vec![b"\"xyzzy\""], + Some(ETag(EntityTag::new(false, "xyzzy".to_owned())))); + test_header!(test2, + vec![b"W/\"xyzzy\""], + Some(ETag(EntityTag::new(true, "xyzzy".to_owned())))); + test_header!(test3, + vec![b"\"\""], + Some(ETag(EntityTag::new(false, "".to_owned())))); + // Own tests + test_header!(test4, + vec![b"\"foobar\""], + Some(ETag(EntityTag::new(false, "foobar".to_owned())))); + test_header!(test5, + vec![b"\"\""], + Some(ETag(EntityTag::new(false, "".to_owned())))); + test_header!(test6, + vec![b"W/\"weak-etag\""], + Some(ETag(EntityTag::new(true, "weak-etag".to_owned())))); + test_header!(test7, + vec![b"W/\"\x65\x62\""], + Some(ETag(EntityTag::new(true, "\u{0065}\u{0062}".to_owned())))); + test_header!(test8, + vec![b"W/\"\""], + Some(ETag(EntityTag::new(true, "".to_owned())))); + test_header!(test9, + vec![b"no-dquotes"], + None::); + test_header!(test10, + vec![b"w/\"the-first-w-is-case-sensitive\""], + None::); + test_header!(test11, + vec![b""], + None::); + test_header!(test12, + vec![b"\"unmatched-dquotes1"], + None::); + test_header!(test13, + vec![b"unmatched-dquotes2\""], + None::); + test_header!(test14, + vec![b"matched-\"dquotes\""], + None::); + test_header!(test15, + vec![b"\""], + None::); + } +} diff --git a/actix-http/src/header/common/expires.rs b/actix-http/src/header/common/expires.rs new file mode 100644 index 000000000..3b9a7873d --- /dev/null +++ b/actix-http/src/header/common/expires.rs @@ -0,0 +1,39 @@ +use crate::header::{HttpDate, EXPIRES}; + +header! { + /// `Expires` header, defined in [RFC7234](http://tools.ietf.org/html/rfc7234#section-5.3) + /// + /// The `Expires` header field gives the date/time after which the + /// response is considered stale. + /// + /// The presence of an Expires field does not imply that the original + /// resource will change or cease to exist at, before, or after that + /// time. + /// + /// # ABNF + /// + /// ```text + /// Expires = HTTP-date + /// ``` + /// + /// # Example values + /// * `Thu, 01 Dec 1994 16:00:00 GMT` + /// + /// # Example + /// + /// ```rust + /// use actix_http::Response; + /// use actix_http::http::header::Expires; + /// use std::time::{SystemTime, Duration}; + /// + /// let mut builder = Response::Ok(); + /// let expiration = SystemTime::now() + Duration::from_secs(60 * 60 * 24); + /// builder.set(Expires(expiration.into())); + /// ``` + (Expires, EXPIRES) => [HttpDate] + + test_expires { + // Test case from RFC + test_header!(test1, vec![b"Thu, 01 Dec 1994 16:00:00 GMT"]); + } +} diff --git a/actix-http/src/header/common/if_match.rs b/actix-http/src/header/common/if_match.rs new file mode 100644 index 000000000..7e0e9a7e0 --- /dev/null +++ b/actix-http/src/header/common/if_match.rs @@ -0,0 +1,70 @@ +use crate::header::{EntityTag, IF_MATCH}; + +header! { + /// `If-Match` header, defined in + /// [RFC7232](https://tools.ietf.org/html/rfc7232#section-3.1) + /// + /// The `If-Match` header field makes the request method conditional on + /// the recipient origin server either having at least one current + /// representation of the target resource, when the field-value is "*", + /// or having a current representation of the target resource that has an + /// entity-tag matching a member of the list of entity-tags provided in + /// the field-value. + /// + /// An origin server MUST use the strong comparison function when + /// comparing entity-tags for `If-Match`, since the client + /// intends this precondition to prevent the method from being applied if + /// there have been any changes to the representation data. + /// + /// # ABNF + /// + /// ```text + /// If-Match = "*" / 1#entity-tag + /// ``` + /// + /// # Example values + /// + /// * `"xyzzy"` + /// * "xyzzy", "r2d2xxxx", "c3piozzzz" + /// + /// # Examples + /// + /// ```rust + /// use actix_http::Response; + /// use actix_http::http::header::IfMatch; + /// + /// let mut builder = Response::Ok(); + /// builder.set(IfMatch::Any); + /// ``` + /// + /// ```rust + /// use actix_http::Response; + /// use actix_http::http::header::{IfMatch, EntityTag}; + /// + /// let mut builder = Response::Ok(); + /// builder.set( + /// IfMatch::Items(vec![ + /// EntityTag::new(false, "xyzzy".to_owned()), + /// EntityTag::new(false, "foobar".to_owned()), + /// EntityTag::new(false, "bazquux".to_owned()), + /// ]) + /// ); + /// ``` + (IfMatch, IF_MATCH) => {Any / (EntityTag)+} + + test_if_match { + test_header!( + test1, + vec![b"\"xyzzy\""], + Some(HeaderField::Items( + vec![EntityTag::new(false, "xyzzy".to_owned())]))); + test_header!( + test2, + vec![b"\"xyzzy\", \"r2d2xxxx\", \"c3piozzzz\""], + Some(HeaderField::Items( + vec![EntityTag::new(false, "xyzzy".to_owned()), + EntityTag::new(false, "r2d2xxxx".to_owned()), + EntityTag::new(false, "c3piozzzz".to_owned())]))); + test_header!(test3, vec![b"*"], Some(IfMatch::Any)); + } +} diff --git a/actix-http/src/header/common/if_modified_since.rs b/actix-http/src/header/common/if_modified_since.rs new file mode 100644 index 000000000..39aca595d --- /dev/null +++ b/actix-http/src/header/common/if_modified_since.rs @@ -0,0 +1,39 @@ +use crate::header::{HttpDate, IF_MODIFIED_SINCE}; + +header! { + /// `If-Modified-Since` header, defined in + /// [RFC7232](http://tools.ietf.org/html/rfc7232#section-3.3) + /// + /// The `If-Modified-Since` header field makes a GET or HEAD request + /// method conditional on the selected representation's modification date + /// being more recent than the date provided in the field-value. + /// Transfer of the selected representation's data is avoided if that + /// data has not changed. + /// + /// # ABNF + /// + /// ```text + /// If-Unmodified-Since = HTTP-date + /// ``` + /// + /// # Example values + /// * `Sat, 29 Oct 1994 19:43:31 GMT` + /// + /// # Example + /// + /// ```rust + /// use actix_http::Response; + /// use actix_http::http::header::IfModifiedSince; + /// use std::time::{SystemTime, Duration}; + /// + /// let mut builder = Response::Ok(); + /// let modified = SystemTime::now() - Duration::from_secs(60 * 60 * 24); + /// builder.set(IfModifiedSince(modified.into())); + /// ``` + (IfModifiedSince, IF_MODIFIED_SINCE) => [HttpDate] + + test_if_modified_since { + // Test case from RFC + test_header!(test1, vec![b"Sat, 29 Oct 1994 19:43:31 GMT"]); + } +} diff --git a/actix-http/src/header/common/if_none_match.rs b/actix-http/src/header/common/if_none_match.rs new file mode 100644 index 000000000..7f6ccb137 --- /dev/null +++ b/actix-http/src/header/common/if_none_match.rs @@ -0,0 +1,92 @@ +use crate::header::{EntityTag, IF_NONE_MATCH}; + +header! { + /// `If-None-Match` header, defined in + /// [RFC7232](https://tools.ietf.org/html/rfc7232#section-3.2) + /// + /// The `If-None-Match` header field makes the request method conditional + /// on a recipient cache or origin server either not having any current + /// representation of the target resource, when the field-value is "*", + /// or having a selected representation with an entity-tag that does not + /// match any of those listed in the field-value. + /// + /// A recipient MUST use the weak comparison function when comparing + /// entity-tags for If-None-Match (Section 2.3.2), since weak entity-tags + /// can be used for cache validation even if there have been changes to + /// the representation data. + /// + /// # ABNF + /// + /// ```text + /// If-None-Match = "*" / 1#entity-tag + /// ``` + /// + /// # Example values + /// + /// * `"xyzzy"` + /// * `W/"xyzzy"` + /// * `"xyzzy", "r2d2xxxx", "c3piozzzz"` + /// * `W/"xyzzy", W/"r2d2xxxx", W/"c3piozzzz"` + /// * `*` + /// + /// # Examples + /// + /// ```rust + /// use actix_http::Response; + /// use actix_http::http::header::IfNoneMatch; + /// + /// let mut builder = Response::Ok(); + /// builder.set(IfNoneMatch::Any); + /// ``` + /// + /// ```rust + /// use actix_http::Response; + /// use actix_http::http::header::{IfNoneMatch, EntityTag}; + /// + /// let mut builder = Response::Ok(); + /// builder.set( + /// IfNoneMatch::Items(vec![ + /// EntityTag::new(false, "xyzzy".to_owned()), + /// EntityTag::new(false, "foobar".to_owned()), + /// EntityTag::new(false, "bazquux".to_owned()), + /// ]) + /// ); + /// ``` + (IfNoneMatch, IF_NONE_MATCH) => {Any / (EntityTag)+} + + test_if_none_match { + test_header!(test1, vec![b"\"xyzzy\""]); + test_header!(test2, vec![b"W/\"xyzzy\""]); + test_header!(test3, vec![b"\"xyzzy\", \"r2d2xxxx\", \"c3piozzzz\""]); + test_header!(test4, vec![b"W/\"xyzzy\", W/\"r2d2xxxx\", W/\"c3piozzzz\""]); + test_header!(test5, vec![b"*"]); + } +} + +#[cfg(test)] +mod tests { + use super::IfNoneMatch; + use crate::header::{EntityTag, Header, IF_NONE_MATCH}; + use crate::test::TestRequest; + + #[test] + fn test_if_none_match() { + let mut if_none_match: Result; + + let req = TestRequest::with_header(IF_NONE_MATCH, "*").finish(); + if_none_match = Header::parse(&req); + assert_eq!(if_none_match.ok(), Some(IfNoneMatch::Any)); + + let req = + TestRequest::with_header(IF_NONE_MATCH, &b"\"foobar\", W/\"weak-etag\""[..]) + .finish(); + + if_none_match = Header::parse(&req); + let mut entities: Vec = Vec::new(); + let foobar_etag = EntityTag::new(false, "foobar".to_owned()); + let weak_etag = EntityTag::new(true, "weak-etag".to_owned()); + entities.push(foobar_etag); + entities.push(weak_etag); + assert_eq!(if_none_match.ok(), Some(IfNoneMatch::Items(entities))); + } +} diff --git a/actix-http/src/header/common/if_range.rs b/actix-http/src/header/common/if_range.rs new file mode 100644 index 000000000..2140ccbb3 --- /dev/null +++ b/actix-http/src/header/common/if_range.rs @@ -0,0 +1,116 @@ +use std::fmt::{self, Display, Write}; + +use crate::error::ParseError; +use crate::header::{ + self, from_one_raw_str, EntityTag, Header, HeaderName, HeaderValue, HttpDate, + IntoHeaderValue, InvalidHeaderValueBytes, Writer, +}; +use crate::httpmessage::HttpMessage; + +/// `If-Range` header, defined in [RFC7233](http://tools.ietf.org/html/rfc7233#section-3.2) +/// +/// If a client has a partial copy of a representation and wishes to have +/// an up-to-date copy of the entire representation, it could use the +/// Range header field with a conditional GET (using either or both of +/// If-Unmodified-Since and If-Match.) However, if the precondition +/// fails because the representation has been modified, the client would +/// then have to make a second request to obtain the entire current +/// representation. +/// +/// The `If-Range` header field allows a client to \"short-circuit\" the +/// second request. Informally, its meaning is as follows: if the +/// representation is unchanged, send me the part(s) that I am requesting +/// in Range; otherwise, send me the entire representation. +/// +/// # ABNF +/// +/// ```text +/// If-Range = entity-tag / HTTP-date +/// ``` +/// +/// # Example values +/// +/// * `Sat, 29 Oct 1994 19:43:31 GMT` +/// * `\"xyzzy\"` +/// +/// # Examples +/// +/// ```rust +/// use actix_http::Response; +/// use actix_http::http::header::{EntityTag, IfRange}; +/// +/// let mut builder = Response::Ok(); +/// builder.set(IfRange::EntityTag(EntityTag::new( +/// false, +/// "xyzzy".to_owned(), +/// ))); +/// ``` +/// +/// ```rust +/// use actix_http::Response; +/// use actix_http::http::header::IfRange; +/// use std::time::{Duration, SystemTime}; +/// +/// let mut builder = Response::Ok(); +/// let fetched = SystemTime::now() - Duration::from_secs(60 * 60 * 24); +/// builder.set(IfRange::Date(fetched.into())); +/// ``` +#[derive(Clone, Debug, PartialEq)] +pub enum IfRange { + /// The entity-tag the client has of the resource + EntityTag(EntityTag), + /// The date when the client retrieved the resource + Date(HttpDate), +} + +impl Header for IfRange { + fn name() -> HeaderName { + header::IF_RANGE + } + #[inline] + fn parse(msg: &T) -> Result + where + T: HttpMessage, + { + let etag: Result = + from_one_raw_str(msg.headers().get(header::IF_RANGE)); + if let Ok(etag) = etag { + return Ok(IfRange::EntityTag(etag)); + } + let date: Result = + from_one_raw_str(msg.headers().get(header::IF_RANGE)); + if let Ok(date) = date { + return Ok(IfRange::Date(date)); + } + Err(ParseError::Header) + } +} + +impl Display for IfRange { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + IfRange::EntityTag(ref x) => Display::fmt(x, f), + IfRange::Date(ref x) => Display::fmt(x, f), + } + } +} + +impl IntoHeaderValue for IfRange { + type Error = InvalidHeaderValueBytes; + + fn try_into(self) -> Result { + let mut writer = Writer::new(); + let _ = write!(&mut writer, "{}", self); + HeaderValue::from_shared(writer.take()) + } +} + +#[cfg(test)] +mod test_if_range { + use super::IfRange as HeaderField; + use crate::header::*; + use std::str; + test_header!(test1, vec![b"Sat, 29 Oct 1994 19:43:31 GMT"]); + test_header!(test2, vec![b"\"xyzzy\""]); + test_header!(test3, vec![b"this-is-invalid"], None::); +} diff --git a/actix-http/src/header/common/if_unmodified_since.rs b/actix-http/src/header/common/if_unmodified_since.rs new file mode 100644 index 000000000..d6c099e64 --- /dev/null +++ b/actix-http/src/header/common/if_unmodified_since.rs @@ -0,0 +1,40 @@ +use crate::header::{HttpDate, IF_UNMODIFIED_SINCE}; + +header! { + /// `If-Unmodified-Since` header, defined in + /// [RFC7232](http://tools.ietf.org/html/rfc7232#section-3.4) + /// + /// The `If-Unmodified-Since` header field makes the request method + /// conditional on the selected representation's last modification date + /// being earlier than or equal to the date provided in the field-value. + /// This field accomplishes the same purpose as If-Match for cases where + /// the user agent does not have an entity-tag for the representation. + /// + /// # ABNF + /// + /// ```text + /// If-Unmodified-Since = HTTP-date + /// ``` + /// + /// # Example values + /// + /// * `Sat, 29 Oct 1994 19:43:31 GMT` + /// + /// # Example + /// + /// ```rust + /// use actix_http::Response; + /// use actix_http::http::header::IfUnmodifiedSince; + /// use std::time::{SystemTime, Duration}; + /// + /// let mut builder = Response::Ok(); + /// let modified = SystemTime::now() - Duration::from_secs(60 * 60 * 24); + /// builder.set(IfUnmodifiedSince(modified.into())); + /// ``` + (IfUnmodifiedSince, IF_UNMODIFIED_SINCE) => [HttpDate] + + test_if_unmodified_since { + // Test case from RFC + test_header!(test1, vec![b"Sat, 29 Oct 1994 19:43:31 GMT"]); + } +} diff --git a/actix-http/src/header/common/last_modified.rs b/actix-http/src/header/common/last_modified.rs new file mode 100644 index 000000000..cc888ccb0 --- /dev/null +++ b/actix-http/src/header/common/last_modified.rs @@ -0,0 +1,38 @@ +use crate::header::{HttpDate, LAST_MODIFIED}; + +header! { + /// `Last-Modified` header, defined in + /// [RFC7232](http://tools.ietf.org/html/rfc7232#section-2.2) + /// + /// The `Last-Modified` header field in a response provides a timestamp + /// indicating the date and time at which the origin server believes the + /// selected representation was last modified, as determined at the + /// conclusion of handling the request. + /// + /// # ABNF + /// + /// ```text + /// Expires = HTTP-date + /// ``` + /// + /// # Example values + /// + /// * `Sat, 29 Oct 1994 19:43:31 GMT` + /// + /// # Example + /// + /// ```rust + /// use actix_http::Response; + /// use actix_http::http::header::LastModified; + /// use std::time::{SystemTime, Duration}; + /// + /// let mut builder = Response::Ok(); + /// let modified = SystemTime::now() - Duration::from_secs(60 * 60 * 24); + /// builder.set(LastModified(modified.into())); + /// ``` + (LastModified, LAST_MODIFIED) => [HttpDate] + + test_last_modified { + // Test case from RFC + test_header!(test1, vec![b"Sat, 29 Oct 1994 19:43:31 GMT"]);} +} diff --git a/actix-http/src/header/common/mod.rs b/actix-http/src/header/common/mod.rs new file mode 100644 index 000000000..30dfcaa6d --- /dev/null +++ b/actix-http/src/header/common/mod.rs @@ -0,0 +1,352 @@ +//! A Collection of Header implementations for common HTTP Headers. +//! +//! ## Mime +//! +//! Several header fields use MIME values for their contents. Keeping with the +//! strongly-typed theme, the [mime](https://docs.rs/mime) crate +//! is used, such as `ContentType(pub Mime)`. +#![cfg_attr(rustfmt, rustfmt_skip)] + +pub use self::accept_charset::AcceptCharset; +//pub use self::accept_encoding::AcceptEncoding; +pub use self::accept_language::AcceptLanguage; +pub use self::accept::Accept; +pub use self::allow::Allow; +pub use self::cache_control::{CacheControl, CacheDirective}; +pub use self::content_disposition::{ContentDisposition, DispositionType, DispositionParam}; +pub use self::content_language::ContentLanguage; +pub use self::content_range::{ContentRange, ContentRangeSpec}; +pub use self::content_type::ContentType; +pub use self::date::Date; +pub use self::etag::ETag; +pub use self::expires::Expires; +pub use self::if_match::IfMatch; +pub use self::if_modified_since::IfModifiedSince; +pub use self::if_none_match::IfNoneMatch; +pub use self::if_range::IfRange; +pub use self::if_unmodified_since::IfUnmodifiedSince; +pub use self::last_modified::LastModified; +//pub use self::range::{Range, ByteRangeSpec}; + +#[doc(hidden)] +#[macro_export] +macro_rules! __hyper__deref { + ($from:ty => $to:ty) => { + impl ::std::ops::Deref for $from { + type Target = $to; + + #[inline] + fn deref(&self) -> &$to { + &self.0 + } + } + + impl ::std::ops::DerefMut for $from { + #[inline] + fn deref_mut(&mut self) -> &mut $to { + &mut self.0 + } + } + } +} + +#[doc(hidden)] +#[macro_export] +macro_rules! __hyper__tm { + ($id:ident, $tm:ident{$($tf:item)*}) => { + #[allow(unused_imports)] + #[cfg(test)] + mod $tm{ + use std::str; + use http::Method; + use mime::*; + use $crate::header::*; + use super::$id as HeaderField; + $($tf)* + } + + } +} + +#[doc(hidden)] +#[macro_export] +macro_rules! test_header { + ($id:ident, $raw:expr) => { + #[test] + fn $id() { + use $crate::test; + use super::*; + + let raw = $raw; + let a: Vec> = raw.iter().map(|x| x.to_vec()).collect(); + let mut req = test::TestRequest::default(); + for item in a { + req = req.header(HeaderField::name(), item).take(); + } + let req = req.finish(); + let value = HeaderField::parse(&req); + let result = format!("{}", value.unwrap()); + let expected = String::from_utf8(raw[0].to_vec()).unwrap(); + let result_cmp: Vec = result + .to_ascii_lowercase() + .split(' ') + .map(|x| x.to_owned()) + .collect(); + let expected_cmp: Vec = expected + .to_ascii_lowercase() + .split(' ') + .map(|x| x.to_owned()) + .collect(); + assert_eq!(result_cmp.concat(), expected_cmp.concat()); + } + }; + ($id:ident, $raw:expr, $typed:expr) => { + #[test] + fn $id() { + use $crate::test; + + let a: Vec> = $raw.iter().map(|x| x.to_vec()).collect(); + let mut req = test::TestRequest::default(); + for item in a { + req.header(HeaderField::name(), item); + } + let req = req.finish(); + let val = HeaderField::parse(&req); + let typed: Option = $typed; + // Test parsing + assert_eq!(val.ok(), typed); + // Test formatting + if typed.is_some() { + let raw = &($raw)[..]; + let mut iter = raw.iter().map(|b|str::from_utf8(&b[..]).unwrap()); + let mut joined = String::new(); + joined.push_str(iter.next().unwrap()); + for s in iter { + joined.push_str(", "); + joined.push_str(s); + } + assert_eq!(format!("{}", typed.unwrap()), joined); + } + } + } +} + +#[macro_export] +macro_rules! header { + // $a:meta: Attributes associated with the header item (usually docs) + // $id:ident: Identifier of the header + // $n:expr: Lowercase name of the header + // $nn:expr: Nice name of the header + + // List header, zero or more items + ($(#[$a:meta])*($id:ident, $name:expr) => ($item:ty)*) => { + $(#[$a])* + #[derive(Clone, Debug, PartialEq)] + pub struct $id(pub Vec<$item>); + __hyper__deref!($id => Vec<$item>); + impl $crate::http::header::Header for $id { + #[inline] + fn name() -> $crate::http::header::HeaderName { + $name + } + #[inline] + fn parse(msg: &T) -> Result + where T: $crate::HttpMessage + { + $crate::http::header::from_comma_delimited( + msg.headers().get_all(Self::name())).map($id) + } + } + impl std::fmt::Display for $id { + #[inline] + fn fmt(&self, f: &mut std::fmt::Formatter) -> ::std::fmt::Result { + $crate::http::header::fmt_comma_delimited(f, &self.0[..]) + } + } + impl $crate::http::header::IntoHeaderValue for $id { + type Error = $crate::http::header::InvalidHeaderValueBytes; + + fn try_into(self) -> Result<$crate::http::header::HeaderValue, Self::Error> { + use std::fmt::Write; + let mut writer = $crate::http::header::Writer::new(); + let _ = write!(&mut writer, "{}", self); + $crate::http::header::HeaderValue::from_shared(writer.take()) + } + } + }; + // List header, one or more items + ($(#[$a:meta])*($id:ident, $name:expr) => ($item:ty)+) => { + $(#[$a])* + #[derive(Clone, Debug, PartialEq)] + pub struct $id(pub Vec<$item>); + __hyper__deref!($id => Vec<$item>); + impl $crate::http::header::Header for $id { + #[inline] + fn name() -> $crate::http::header::HeaderName { + $name + } + #[inline] + fn parse(msg: &T) -> Result + where T: $crate::HttpMessage + { + $crate::http::header::from_comma_delimited( + msg.headers().get_all(Self::name())).map($id) + } + } + impl std::fmt::Display for $id { + #[inline] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + $crate::http::header::fmt_comma_delimited(f, &self.0[..]) + } + } + impl $crate::http::header::IntoHeaderValue for $id { + type Error = $crate::http::header::InvalidHeaderValueBytes; + + fn try_into(self) -> Result<$crate::http::header::HeaderValue, Self::Error> { + use std::fmt::Write; + let mut writer = $crate::http::header::Writer::new(); + let _ = write!(&mut writer, "{}", self); + $crate::http::header::HeaderValue::from_shared(writer.take()) + } + } + }; + // Single value header + ($(#[$a:meta])*($id:ident, $name:expr) => [$value:ty]) => { + $(#[$a])* + #[derive(Clone, Debug, PartialEq)] + pub struct $id(pub $value); + __hyper__deref!($id => $value); + impl $crate::http::header::Header for $id { + #[inline] + fn name() -> $crate::http::header::HeaderName { + $name + } + #[inline] + fn parse(msg: &T) -> Result + where T: $crate::HttpMessage + { + $crate::http::header::from_one_raw_str( + msg.headers().get(Self::name())).map($id) + } + } + impl std::fmt::Display for $id { + #[inline] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + std::fmt::Display::fmt(&self.0, f) + } + } + impl $crate::http::header::IntoHeaderValue for $id { + type Error = $crate::http::header::InvalidHeaderValueBytes; + + fn try_into(self) -> Result<$crate::http::header::HeaderValue, Self::Error> { + self.0.try_into() + } + } + }; + // List header, one or more items with "*" option + ($(#[$a:meta])*($id:ident, $name:expr) => {Any / ($item:ty)+}) => { + $(#[$a])* + #[derive(Clone, Debug, PartialEq)] + pub enum $id { + /// Any value is a match + Any, + /// Only the listed items are a match + Items(Vec<$item>), + } + impl $crate::http::header::Header for $id { + #[inline] + fn name() -> $crate::http::header::HeaderName { + $name + } + #[inline] + fn parse(msg: &T) -> Result + where T: $crate::HttpMessage + { + let any = msg.headers().get(Self::name()).and_then(|hdr| { + hdr.to_str().ok().and_then(|hdr| Some(hdr.trim() == "*"))}); + + if let Some(true) = any { + Ok($id::Any) + } else { + Ok($id::Items( + $crate::http::header::from_comma_delimited( + msg.headers().get_all(Self::name()))?)) + } + } + } + impl std::fmt::Display for $id { + #[inline] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match *self { + $id::Any => f.write_str("*"), + $id::Items(ref fields) => $crate::http::header::fmt_comma_delimited( + f, &fields[..]) + } + } + } + impl $crate::http::header::IntoHeaderValue for $id { + type Error = $crate::http::header::InvalidHeaderValueBytes; + + fn try_into(self) -> Result<$crate::http::header::HeaderValue, Self::Error> { + use std::fmt::Write; + let mut writer = $crate::http::header::Writer::new(); + let _ = write!(&mut writer, "{}", self); + $crate::http::header::HeaderValue::from_shared(writer.take()) + } + } + }; + + // optional test module + ($(#[$a:meta])*($id:ident, $name:expr) => ($item:ty)* $tm:ident{$($tf:item)*}) => { + header! { + $(#[$a])* + ($id, $name) => ($item)* + } + + __hyper__tm! { $id, $tm { $($tf)* }} + }; + ($(#[$a:meta])*($id:ident, $n:expr) => ($item:ty)+ $tm:ident{$($tf:item)*}) => { + header! { + $(#[$a])* + ($id, $n) => ($item)+ + } + + __hyper__tm! { $id, $tm { $($tf)* }} + }; + ($(#[$a:meta])*($id:ident, $name:expr) => [$item:ty] $tm:ident{$($tf:item)*}) => { + header! { + $(#[$a])* ($id, $name) => [$item] + } + + __hyper__tm! { $id, $tm { $($tf)* }} + }; + ($(#[$a:meta])*($id:ident, $name:expr) => {Any / ($item:ty)+} $tm:ident{$($tf:item)*}) => { + header! { + $(#[$a])* + ($id, $name) => {Any / ($item)+} + } + + __hyper__tm! { $id, $tm { $($tf)* }} + }; +} + + +mod accept_charset; +//mod accept_encoding; +mod accept_language; +mod accept; +mod allow; +mod cache_control; +mod content_disposition; +mod content_language; +mod content_range; +mod content_type; +mod date; +mod etag; +mod expires; +mod if_match; +mod if_modified_since; +mod if_none_match; +mod if_range; +mod if_unmodified_since; +mod last_modified; diff --git a/actix-http/src/header/common/range.rs b/actix-http/src/header/common/range.rs new file mode 100644 index 000000000..71718fc7a --- /dev/null +++ b/actix-http/src/header/common/range.rs @@ -0,0 +1,434 @@ +use std::fmt::{self, Display}; +use std::str::FromStr; + +use header::parsing::from_one_raw_str; +use header::{Header, Raw}; + +/// `Range` header, defined in [RFC7233](https://tools.ietf.org/html/rfc7233#section-3.1) +/// +/// The "Range" header field on a GET request modifies the method +/// semantics to request transfer of only one or more subranges of the +/// selected representation data, rather than the entire selected +/// representation data. +/// +/// # ABNF +/// +/// ```text +/// Range = byte-ranges-specifier / other-ranges-specifier +/// other-ranges-specifier = other-range-unit "=" other-range-set +/// other-range-set = 1*VCHAR +/// +/// bytes-unit = "bytes" +/// +/// byte-ranges-specifier = bytes-unit "=" byte-range-set +/// byte-range-set = 1#(byte-range-spec / suffix-byte-range-spec) +/// byte-range-spec = first-byte-pos "-" [last-byte-pos] +/// first-byte-pos = 1*DIGIT +/// last-byte-pos = 1*DIGIT +/// ``` +/// +/// # Example values +/// +/// * `bytes=1000-` +/// * `bytes=-2000` +/// * `bytes=0-1,30-40` +/// * `bytes=0-10,20-90,-100` +/// * `custom_unit=0-123` +/// * `custom_unit=xxx-yyy` +/// +/// # Examples +/// +/// ``` +/// use hyper::header::{Headers, Range, ByteRangeSpec}; +/// +/// let mut headers = Headers::new(); +/// headers.set(Range::Bytes( +/// vec![ByteRangeSpec::FromTo(1, 100), ByteRangeSpec::AllFrom(200)] +/// )); +/// +/// headers.clear(); +/// headers.set(Range::Unregistered("letters".to_owned(), "a-f".to_owned())); +/// ``` +/// +/// ``` +/// use hyper::header::{Headers, Range}; +/// +/// let mut headers = Headers::new(); +/// headers.set(Range::bytes(1, 100)); +/// +/// headers.clear(); +/// headers.set(Range::bytes_multi(vec![(1, 100), (200, 300)])); +/// ``` +#[derive(PartialEq, Clone, Debug)] +pub enum Range { + /// Byte range + Bytes(Vec), + /// Custom range, with unit not registered at IANA + /// (`other-range-unit`: String , `other-range-set`: String) + Unregistered(String, String), +} + +/// Each `Range::Bytes` header can contain one or more `ByteRangeSpecs`. +/// Each `ByteRangeSpec` defines a range of bytes to fetch +#[derive(PartialEq, Clone, Debug)] +pub enum ByteRangeSpec { + /// Get all bytes between x and y ("x-y") + FromTo(u64, u64), + /// Get all bytes starting from x ("x-") + AllFrom(u64), + /// Get last x bytes ("-x") + Last(u64), +} + +impl ByteRangeSpec { + /// Given the full length of the entity, attempt to normalize the byte range + /// into an satisfiable end-inclusive (from, to) range. + /// + /// The resulting range is guaranteed to be a satisfiable range within the + /// bounds of `0 <= from <= to < full_length`. + /// + /// If the byte range is deemed unsatisfiable, `None` is returned. + /// An unsatisfiable range is generally cause for a server to either reject + /// the client request with a `416 Range Not Satisfiable` status code, or to + /// simply ignore the range header and serve the full entity using a `200 + /// OK` status code. + /// + /// This function closely follows [RFC 7233][1] section 2.1. + /// As such, it considers ranges to be satisfiable if they meet the + /// following conditions: + /// + /// > If a valid byte-range-set includes at least one byte-range-spec with + /// a first-byte-pos that is less than the current length of the + /// representation, or at least one suffix-byte-range-spec with a + /// non-zero suffix-length, then the byte-range-set is satisfiable. + /// Otherwise, the byte-range-set is unsatisfiable. + /// + /// The function also computes remainder ranges based on the RFC: + /// + /// > If the last-byte-pos value is + /// absent, or if the value is greater than or equal to the current + /// length of the representation data, the byte range is interpreted as + /// the remainder of the representation (i.e., the server replaces the + /// value of last-byte-pos with a value that is one less than the current + /// length of the selected representation). + /// + /// [1]: https://tools.ietf.org/html/rfc7233 + pub fn to_satisfiable_range(&self, full_length: u64) -> Option<(u64, u64)> { + // If the full length is zero, there is no satisfiable end-inclusive range. + if full_length == 0 { + return None; + } + match self { + &ByteRangeSpec::FromTo(from, to) => { + if from < full_length && from <= to { + Some((from, ::std::cmp::min(to, full_length - 1))) + } else { + None + } + } + &ByteRangeSpec::AllFrom(from) => { + if from < full_length { + Some((from, full_length - 1)) + } else { + None + } + } + &ByteRangeSpec::Last(last) => { + if last > 0 { + // From the RFC: If the selected representation is shorter + // than the specified suffix-length, + // the entire representation is used. + if last > full_length { + Some((0, full_length - 1)) + } else { + Some((full_length - last, full_length - 1)) + } + } else { + None + } + } + } + } +} + +impl Range { + /// Get the most common byte range header ("bytes=from-to") + pub fn bytes(from: u64, to: u64) -> Range { + Range::Bytes(vec![ByteRangeSpec::FromTo(from, to)]) + } + + /// Get byte range header with multiple subranges + /// ("bytes=from1-to1,from2-to2,fromX-toX") + pub fn bytes_multi(ranges: Vec<(u64, u64)>) -> Range { + Range::Bytes( + ranges + .iter() + .map(|r| ByteRangeSpec::FromTo(r.0, r.1)) + .collect(), + ) + } +} + +impl fmt::Display for ByteRangeSpec { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + ByteRangeSpec::FromTo(from, to) => write!(f, "{}-{}", from, to), + ByteRangeSpec::Last(pos) => write!(f, "-{}", pos), + ByteRangeSpec::AllFrom(pos) => write!(f, "{}-", pos), + } + } +} + +impl fmt::Display for Range { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + Range::Bytes(ref ranges) => { + try!(write!(f, "bytes=")); + + for (i, range) in ranges.iter().enumerate() { + if i != 0 { + try!(f.write_str(",")); + } + try!(Display::fmt(range, f)); + } + Ok(()) + } + Range::Unregistered(ref unit, ref range_str) => { + write!(f, "{}={}", unit, range_str) + } + } + } +} + +impl FromStr for Range { + type Err = ::Error; + + fn from_str(s: &str) -> ::Result { + let mut iter = s.splitn(2, '='); + + match (iter.next(), iter.next()) { + (Some("bytes"), Some(ranges)) => { + let ranges = from_comma_delimited(ranges); + if ranges.is_empty() { + return Err(::Error::Header); + } + Ok(Range::Bytes(ranges)) + } + (Some(unit), Some(range_str)) if unit != "" && range_str != "" => Ok( + Range::Unregistered(unit.to_owned(), range_str.to_owned()), + ), + _ => Err(::Error::Header), + } + } +} + +impl FromStr for ByteRangeSpec { + type Err = ::Error; + + fn from_str(s: &str) -> ::Result { + let mut parts = s.splitn(2, '-'); + + match (parts.next(), parts.next()) { + (Some(""), Some(end)) => end.parse() + .or(Err(::Error::Header)) + .map(ByteRangeSpec::Last), + (Some(start), Some("")) => start + .parse() + .or(Err(::Error::Header)) + .map(ByteRangeSpec::AllFrom), + (Some(start), Some(end)) => match (start.parse(), end.parse()) { + (Ok(start), Ok(end)) if start <= end => { + Ok(ByteRangeSpec::FromTo(start, end)) + } + _ => Err(::Error::Header), + }, + _ => Err(::Error::Header), + } + } +} + +fn from_comma_delimited(s: &str) -> Vec { + s.split(',') + .filter_map(|x| match x.trim() { + "" => None, + y => Some(y), + }) + .filter_map(|x| x.parse().ok()) + .collect() +} + +impl Header for Range { + fn header_name() -> &'static str { + static NAME: &'static str = "Range"; + NAME + } + + fn parse_header(raw: &Raw) -> ::Result { + from_one_raw_str(raw) + } + + fn fmt_header(&self, f: &mut ::header::Formatter) -> fmt::Result { + f.fmt_line(self) + } +} + +#[test] +fn test_parse_bytes_range_valid() { + let r: Range = Header::parse_header(&"bytes=1-100".into()).unwrap(); + let r2: Range = Header::parse_header(&"bytes=1-100,-".into()).unwrap(); + let r3 = Range::bytes(1, 100); + assert_eq!(r, r2); + assert_eq!(r2, r3); + + let r: Range = Header::parse_header(&"bytes=1-100,200-".into()).unwrap(); + let r2: Range = + Header::parse_header(&"bytes= 1-100 , 101-xxx, 200- ".into()).unwrap(); + let r3 = Range::Bytes(vec![ + ByteRangeSpec::FromTo(1, 100), + ByteRangeSpec::AllFrom(200), + ]); + assert_eq!(r, r2); + assert_eq!(r2, r3); + + let r: Range = Header::parse_header(&"bytes=1-100,-100".into()).unwrap(); + let r2: Range = Header::parse_header(&"bytes=1-100, ,,-100".into()).unwrap(); + let r3 = Range::Bytes(vec![ + ByteRangeSpec::FromTo(1, 100), + ByteRangeSpec::Last(100), + ]); + assert_eq!(r, r2); + assert_eq!(r2, r3); + + let r: Range = Header::parse_header(&"custom=1-100,-100".into()).unwrap(); + let r2 = Range::Unregistered("custom".to_owned(), "1-100,-100".to_owned()); + assert_eq!(r, r2); +} + +#[test] +fn test_parse_unregistered_range_valid() { + let r: Range = Header::parse_header(&"custom=1-100,-100".into()).unwrap(); + let r2 = Range::Unregistered("custom".to_owned(), "1-100,-100".to_owned()); + assert_eq!(r, r2); + + let r: Range = Header::parse_header(&"custom=abcd".into()).unwrap(); + let r2 = Range::Unregistered("custom".to_owned(), "abcd".to_owned()); + assert_eq!(r, r2); + + let r: Range = Header::parse_header(&"custom=xxx-yyy".into()).unwrap(); + let r2 = Range::Unregistered("custom".to_owned(), "xxx-yyy".to_owned()); + assert_eq!(r, r2); +} + +#[test] +fn test_parse_invalid() { + let r: ::Result = Header::parse_header(&"bytes=1-a,-".into()); + assert_eq!(r.ok(), None); + + let r: ::Result = Header::parse_header(&"bytes=1-2-3".into()); + assert_eq!(r.ok(), None); + + let r: ::Result = Header::parse_header(&"abc".into()); + assert_eq!(r.ok(), None); + + let r: ::Result = Header::parse_header(&"bytes=1-100=".into()); + assert_eq!(r.ok(), None); + + let r: ::Result = Header::parse_header(&"bytes=".into()); + assert_eq!(r.ok(), None); + + let r: ::Result = Header::parse_header(&"custom=".into()); + assert_eq!(r.ok(), None); + + let r: ::Result = Header::parse_header(&"=1-100".into()); + assert_eq!(r.ok(), None); +} + +#[test] +fn test_fmt() { + use header::Headers; + + let mut headers = Headers::new(); + + headers.set(Range::Bytes(vec![ + ByteRangeSpec::FromTo(0, 1000), + ByteRangeSpec::AllFrom(2000), + ])); + assert_eq!(&headers.to_string(), "Range: bytes=0-1000,2000-\r\n"); + + headers.clear(); + headers.set(Range::Bytes(vec![])); + + assert_eq!(&headers.to_string(), "Range: bytes=\r\n"); + + headers.clear(); + headers.set(Range::Unregistered( + "custom".to_owned(), + "1-xxx".to_owned(), + )); + + assert_eq!(&headers.to_string(), "Range: custom=1-xxx\r\n"); +} + +#[test] +fn test_byte_range_spec_to_satisfiable_range() { + assert_eq!( + Some((0, 0)), + ByteRangeSpec::FromTo(0, 0).to_satisfiable_range(3) + ); + assert_eq!( + Some((1, 2)), + ByteRangeSpec::FromTo(1, 2).to_satisfiable_range(3) + ); + assert_eq!( + Some((1, 2)), + ByteRangeSpec::FromTo(1, 5).to_satisfiable_range(3) + ); + assert_eq!( + None, + ByteRangeSpec::FromTo(3, 3).to_satisfiable_range(3) + ); + assert_eq!( + None, + ByteRangeSpec::FromTo(2, 1).to_satisfiable_range(3) + ); + assert_eq!( + None, + ByteRangeSpec::FromTo(0, 0).to_satisfiable_range(0) + ); + + assert_eq!( + Some((0, 2)), + ByteRangeSpec::AllFrom(0).to_satisfiable_range(3) + ); + assert_eq!( + Some((2, 2)), + ByteRangeSpec::AllFrom(2).to_satisfiable_range(3) + ); + assert_eq!( + None, + ByteRangeSpec::AllFrom(3).to_satisfiable_range(3) + ); + assert_eq!( + None, + ByteRangeSpec::AllFrom(5).to_satisfiable_range(3) + ); + assert_eq!( + None, + ByteRangeSpec::AllFrom(0).to_satisfiable_range(0) + ); + + assert_eq!( + Some((1, 2)), + ByteRangeSpec::Last(2).to_satisfiable_range(3) + ); + assert_eq!( + Some((2, 2)), + ByteRangeSpec::Last(1).to_satisfiable_range(3) + ); + assert_eq!( + Some((0, 2)), + ByteRangeSpec::Last(5).to_satisfiable_range(3) + ); + assert_eq!(None, ByteRangeSpec::Last(0).to_satisfiable_range(3)); + assert_eq!(None, ByteRangeSpec::Last(2).to_satisfiable_range(0)); +} diff --git a/actix-http/src/header/mod.rs b/actix-http/src/header/mod.rs new file mode 100644 index 000000000..1ef1bd198 --- /dev/null +++ b/actix-http/src/header/mod.rs @@ -0,0 +1,465 @@ +//! Various http headers +// This is mostly copy of [hyper](https://github.com/hyperium/hyper/tree/master/src/header) + +use std::{fmt, str::FromStr}; + +use bytes::{Bytes, BytesMut}; +use http::header::GetAll; +use http::Error as HttpError; +use mime::Mime; + +pub use http::header::*; + +use crate::error::ParseError; +use crate::httpmessage::HttpMessage; + +mod common; +mod shared; +#[doc(hidden)] +pub use self::common::*; +#[doc(hidden)] +pub use self::shared::*; + +#[doc(hidden)] +/// A trait for any object that will represent a header field and value. +pub trait Header +where + Self: IntoHeaderValue, +{ + /// Returns the name of the header field + fn name() -> HeaderName; + + /// Parse a header + fn parse(msg: &T) -> Result; +} + +#[doc(hidden)] +/// A trait for any object that can be Converted to a `HeaderValue` +pub trait IntoHeaderValue: Sized { + /// The type returned in the event of a conversion error. + type Error: Into; + + /// Try to convert value to a Header value. + fn try_into(self) -> Result; +} + +impl IntoHeaderValue for HeaderValue { + type Error = InvalidHeaderValue; + + #[inline] + fn try_into(self) -> Result { + Ok(self) + } +} + +impl<'a> IntoHeaderValue for &'a str { + type Error = InvalidHeaderValue; + + #[inline] + fn try_into(self) -> Result { + self.parse() + } +} + +impl<'a> IntoHeaderValue for &'a [u8] { + type Error = InvalidHeaderValue; + + #[inline] + fn try_into(self) -> Result { + HeaderValue::from_bytes(self) + } +} + +impl IntoHeaderValue for Bytes { + type Error = InvalidHeaderValueBytes; + + #[inline] + fn try_into(self) -> Result { + HeaderValue::from_shared(self) + } +} + +impl IntoHeaderValue for Vec { + type Error = InvalidHeaderValueBytes; + + #[inline] + fn try_into(self) -> Result { + HeaderValue::from_shared(Bytes::from(self)) + } +} + +impl IntoHeaderValue for String { + type Error = InvalidHeaderValueBytes; + + #[inline] + fn try_into(self) -> Result { + HeaderValue::from_shared(Bytes::from(self)) + } +} + +impl IntoHeaderValue for Mime { + type Error = InvalidHeaderValueBytes; + + #[inline] + fn try_into(self) -> Result { + HeaderValue::from_shared(Bytes::from(format!("{}", self))) + } +} + +/// Represents supported types of content encodings +#[derive(Copy, Clone, PartialEq, Debug)] +pub enum ContentEncoding { + /// Automatically select encoding based on encoding negotiation + Auto, + /// A format using the Brotli algorithm + Br, + /// A format using the zlib structure with deflate algorithm + Deflate, + /// Gzip algorithm + Gzip, + /// Indicates the identity function (i.e. no compression, nor modification) + Identity, +} + +impl ContentEncoding { + #[inline] + /// Is the content compressed? + pub fn is_compression(self) -> bool { + match self { + ContentEncoding::Identity | ContentEncoding::Auto => false, + _ => true, + } + } + + #[inline] + /// Convert content encoding to string + pub fn as_str(self) -> &'static str { + match self { + ContentEncoding::Br => "br", + ContentEncoding::Gzip => "gzip", + ContentEncoding::Deflate => "deflate", + ContentEncoding::Identity | ContentEncoding::Auto => "identity", + } + } + + #[inline] + /// default quality value + pub fn quality(self) -> f64 { + match self { + ContentEncoding::Br => 1.1, + ContentEncoding::Gzip => 1.0, + ContentEncoding::Deflate => 0.9, + ContentEncoding::Identity | ContentEncoding::Auto => 0.1, + } + } +} + +impl<'a> From<&'a str> for ContentEncoding { + fn from(s: &'a str) -> ContentEncoding { + let s = s.trim(); + + if s.eq_ignore_ascii_case("br") { + ContentEncoding::Br + } else if s.eq_ignore_ascii_case("gzip") { + ContentEncoding::Gzip + } else if s.eq_ignore_ascii_case("deflate") { + ContentEncoding::Deflate + } else { + ContentEncoding::Identity + } + } +} + +#[doc(hidden)] +pub(crate) struct Writer { + buf: BytesMut, +} + +impl Writer { + fn new() -> Writer { + Writer { + buf: BytesMut::new(), + } + } + fn take(&mut self) -> Bytes { + self.buf.take().freeze() + } +} + +impl fmt::Write for Writer { + #[inline] + fn write_str(&mut self, s: &str) -> fmt::Result { + self.buf.extend_from_slice(s.as_bytes()); + Ok(()) + } + + #[inline] + fn write_fmt(&mut self, args: fmt::Arguments) -> fmt::Result { + fmt::write(self, args) + } +} + +#[inline] +#[doc(hidden)] +/// Reads a comma-delimited raw header into a Vec. +pub fn from_comma_delimited( + all: GetAll, +) -> Result, ParseError> { + let mut result = Vec::new(); + for h in all { + let s = h.to_str().map_err(|_| ParseError::Header)?; + result.extend( + s.split(',') + .filter_map(|x| match x.trim() { + "" => None, + y => Some(y), + }) + .filter_map(|x| x.trim().parse().ok()), + ) + } + Ok(result) +} + +#[inline] +#[doc(hidden)] +/// Reads a single string when parsing a header. +pub fn from_one_raw_str(val: Option<&HeaderValue>) -> Result { + if let Some(line) = val { + let line = line.to_str().map_err(|_| ParseError::Header)?; + if !line.is_empty() { + return T::from_str(line).or(Err(ParseError::Header)); + } + } + Err(ParseError::Header) +} + +#[inline] +#[doc(hidden)] +/// Format an array into a comma-delimited string. +pub fn fmt_comma_delimited(f: &mut fmt::Formatter, parts: &[T]) -> fmt::Result +where + T: fmt::Display, +{ + let mut iter = parts.iter(); + if let Some(part) = iter.next() { + fmt::Display::fmt(part, f)?; + } + for part in iter { + f.write_str(", ")?; + fmt::Display::fmt(part, f)?; + } + Ok(()) +} + +// From hyper v0.11.27 src/header/parsing.rs + +/// The value part of an extended parameter consisting of three parts: +/// the REQUIRED character set name (`charset`), the OPTIONAL language information (`language_tag`), +/// and a character sequence representing the actual value (`value`), separated by single quote +/// characters. It is defined in [RFC 5987](https://tools.ietf.org/html/rfc5987#section-3.2). +#[derive(Clone, Debug, PartialEq)] +pub struct ExtendedValue { + /// The character set that is used to encode the `value` to a string. + pub charset: Charset, + /// The human language details of the `value`, if available. + pub language_tag: Option, + /// The parameter value, as expressed in octets. + pub value: Vec, +} + +/// Parses extended header parameter values (`ext-value`), as defined in +/// [RFC 5987](https://tools.ietf.org/html/rfc5987#section-3.2). +/// +/// Extended values are denoted by parameter names that end with `*`. +/// +/// ## ABNF +/// +/// ```text +/// ext-value = charset "'" [ language ] "'" value-chars +/// ; like RFC 2231's +/// ; (see [RFC2231], Section 7) +/// +/// charset = "UTF-8" / "ISO-8859-1" / mime-charset +/// +/// mime-charset = 1*mime-charsetc +/// mime-charsetc = ALPHA / DIGIT +/// / "!" / "#" / "$" / "%" / "&" +/// / "+" / "-" / "^" / "_" / "`" +/// / "{" / "}" / "~" +/// ; as in Section 2.3 of [RFC2978] +/// ; except that the single quote is not included +/// ; SHOULD be registered in the IANA charset registry +/// +/// language = +/// +/// value-chars = *( pct-encoded / attr-char ) +/// +/// pct-encoded = "%" HEXDIG HEXDIG +/// ; see [RFC3986], Section 2.1 +/// +/// attr-char = ALPHA / DIGIT +/// / "!" / "#" / "$" / "&" / "+" / "-" / "." +/// / "^" / "_" / "`" / "|" / "~" +/// ; token except ( "*" / "'" / "%" ) +/// ``` +pub fn parse_extended_value( + val: &str, +) -> Result { + // Break into three pieces separated by the single-quote character + let mut parts = val.splitn(3, '\''); + + // Interpret the first piece as a Charset + let charset: Charset = match parts.next() { + None => return Err(crate::error::ParseError::Header), + Some(n) => FromStr::from_str(n).map_err(|_| crate::error::ParseError::Header)?, + }; + + // Interpret the second piece as a language tag + let language_tag: Option = match parts.next() { + None => return Err(crate::error::ParseError::Header), + Some("") => None, + Some(s) => match s.parse() { + Ok(lt) => Some(lt), + Err(_) => return Err(crate::error::ParseError::Header), + }, + }; + + // Interpret the third piece as a sequence of value characters + let value: Vec = match parts.next() { + None => return Err(crate::error::ParseError::Header), + Some(v) => percent_encoding::percent_decode(v.as_bytes()).collect(), + }; + + Ok(ExtendedValue { + value, + charset, + language_tag, + }) +} + +impl fmt::Display for ExtendedValue { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let encoded_value = percent_encoding::percent_encode( + &self.value[..], + self::percent_encoding_http::HTTP_VALUE, + ); + if let Some(ref lang) = self.language_tag { + write!(f, "{}'{}'{}", self.charset, lang, encoded_value) + } else { + write!(f, "{}''{}", self.charset, encoded_value) + } + } +} + +/// Percent encode a sequence of bytes with a character set defined in +/// [https://tools.ietf.org/html/rfc5987#section-3.2][url] +/// +/// [url]: https://tools.ietf.org/html/rfc5987#section-3.2 +pub fn http_percent_encode(f: &mut fmt::Formatter, bytes: &[u8]) -> fmt::Result { + let encoded = + percent_encoding::percent_encode(bytes, self::percent_encoding_http::HTTP_VALUE); + fmt::Display::fmt(&encoded, f) +} + +mod percent_encoding_http { + use percent_encoding::{self, define_encode_set}; + + // internal module because macro is hard-coded to make a public item + // but we don't want to public export this item + define_encode_set! { + // This encode set is used for HTTP header values and is defined at + // https://tools.ietf.org/html/rfc5987#section-3.2 + pub HTTP_VALUE = [percent_encoding::SIMPLE_ENCODE_SET] | { + ' ', '"', '%', '\'', '(', ')', '*', ',', '/', ':', ';', '<', '-', '>', '?', + '[', '\\', ']', '{', '}' + } + } +} + +#[cfg(test)] +mod tests { + use super::shared::Charset; + use super::{parse_extended_value, ExtendedValue}; + use language_tags::LanguageTag; + + #[test] + fn test_parse_extended_value_with_encoding_and_language_tag() { + let expected_language_tag = "en".parse::().unwrap(); + // RFC 5987, Section 3.2.2 + // Extended notation, using the Unicode character U+00A3 (POUND SIGN) + let result = parse_extended_value("iso-8859-1'en'%A3%20rates"); + assert!(result.is_ok()); + let extended_value = result.unwrap(); + assert_eq!(Charset::Iso_8859_1, extended_value.charset); + assert!(extended_value.language_tag.is_some()); + assert_eq!(expected_language_tag, extended_value.language_tag.unwrap()); + assert_eq!( + vec![163, b' ', b'r', b'a', b't', b'e', b's'], + extended_value.value + ); + } + + #[test] + fn test_parse_extended_value_with_encoding() { + // RFC 5987, Section 3.2.2 + // Extended notation, using the Unicode characters U+00A3 (POUND SIGN) + // and U+20AC (EURO SIGN) + let result = parse_extended_value("UTF-8''%c2%a3%20and%20%e2%82%ac%20rates"); + assert!(result.is_ok()); + let extended_value = result.unwrap(); + assert_eq!(Charset::Ext("UTF-8".to_string()), extended_value.charset); + assert!(extended_value.language_tag.is_none()); + assert_eq!( + vec![ + 194, 163, b' ', b'a', b'n', b'd', b' ', 226, 130, 172, b' ', b'r', b'a', + b't', b'e', b's', + ], + extended_value.value + ); + } + + #[test] + fn test_parse_extended_value_missing_language_tag_and_encoding() { + // From: https://greenbytes.de/tech/tc2231/#attwithfn2231quot2 + let result = parse_extended_value("foo%20bar.html"); + assert!(result.is_err()); + } + + #[test] + fn test_parse_extended_value_partially_formatted() { + let result = parse_extended_value("UTF-8'missing third part"); + assert!(result.is_err()); + } + + #[test] + fn test_parse_extended_value_partially_formatted_blank() { + let result = parse_extended_value("blank second part'"); + assert!(result.is_err()); + } + + #[test] + fn test_fmt_extended_value_with_encoding_and_language_tag() { + let extended_value = ExtendedValue { + charset: Charset::Iso_8859_1, + language_tag: Some("en".parse().expect("Could not parse language tag")), + value: vec![163, b' ', b'r', b'a', b't', b'e', b's'], + }; + assert_eq!("ISO-8859-1'en'%A3%20rates", format!("{}", extended_value)); + } + + #[test] + fn test_fmt_extended_value_with_encoding() { + let extended_value = ExtendedValue { + charset: Charset::Ext("UTF-8".to_string()), + language_tag: None, + value: vec![ + 194, 163, b' ', b'a', b'n', b'd', b' ', 226, 130, 172, b' ', b'r', b'a', + b't', b'e', b's', + ], + }; + assert_eq!( + "UTF-8''%C2%A3%20and%20%E2%82%AC%20rates", + format!("{}", extended_value) + ); + } +} diff --git a/actix-http/src/header/shared/charset.rs b/actix-http/src/header/shared/charset.rs new file mode 100644 index 000000000..ec3fe3854 --- /dev/null +++ b/actix-http/src/header/shared/charset.rs @@ -0,0 +1,153 @@ +use std::fmt::{self, Display}; +use std::str::FromStr; + +use self::Charset::*; + +/// A Mime charset. +/// +/// The string representation is normalized to upper case. +/// +/// See [http://www.iana.org/assignments/character-sets/character-sets.xhtml][url]. +/// +/// [url]: http://www.iana.org/assignments/character-sets/character-sets.xhtml +#[derive(Clone, Debug, PartialEq)] +#[allow(non_camel_case_types)] +pub enum Charset { + /// US ASCII + Us_Ascii, + /// ISO-8859-1 + Iso_8859_1, + /// ISO-8859-2 + Iso_8859_2, + /// ISO-8859-3 + Iso_8859_3, + /// ISO-8859-4 + Iso_8859_4, + /// ISO-8859-5 + Iso_8859_5, + /// ISO-8859-6 + Iso_8859_6, + /// ISO-8859-7 + Iso_8859_7, + /// ISO-8859-8 + Iso_8859_8, + /// ISO-8859-9 + Iso_8859_9, + /// ISO-8859-10 + Iso_8859_10, + /// Shift_JIS + Shift_Jis, + /// EUC-JP + Euc_Jp, + /// ISO-2022-KR + Iso_2022_Kr, + /// EUC-KR + Euc_Kr, + /// ISO-2022-JP + Iso_2022_Jp, + /// ISO-2022-JP-2 + Iso_2022_Jp_2, + /// ISO-8859-6-E + Iso_8859_6_E, + /// ISO-8859-6-I + Iso_8859_6_I, + /// ISO-8859-8-E + Iso_8859_8_E, + /// ISO-8859-8-I + Iso_8859_8_I, + /// GB2312 + Gb2312, + /// Big5 + Big5, + /// KOI8-R + Koi8_R, + /// An arbitrary charset specified as a string + Ext(String), +} + +impl Charset { + fn label(&self) -> &str { + match *self { + Us_Ascii => "US-ASCII", + Iso_8859_1 => "ISO-8859-1", + Iso_8859_2 => "ISO-8859-2", + Iso_8859_3 => "ISO-8859-3", + Iso_8859_4 => "ISO-8859-4", + Iso_8859_5 => "ISO-8859-5", + Iso_8859_6 => "ISO-8859-6", + Iso_8859_7 => "ISO-8859-7", + Iso_8859_8 => "ISO-8859-8", + Iso_8859_9 => "ISO-8859-9", + Iso_8859_10 => "ISO-8859-10", + Shift_Jis => "Shift-JIS", + Euc_Jp => "EUC-JP", + Iso_2022_Kr => "ISO-2022-KR", + Euc_Kr => "EUC-KR", + Iso_2022_Jp => "ISO-2022-JP", + Iso_2022_Jp_2 => "ISO-2022-JP-2", + Iso_8859_6_E => "ISO-8859-6-E", + Iso_8859_6_I => "ISO-8859-6-I", + Iso_8859_8_E => "ISO-8859-8-E", + Iso_8859_8_I => "ISO-8859-8-I", + Gb2312 => "GB2312", + Big5 => "big5", + Koi8_R => "KOI8-R", + Ext(ref s) => s, + } + } +} + +impl Display for Charset { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str(self.label()) + } +} + +impl FromStr for Charset { + type Err = crate::Error; + + fn from_str(s: &str) -> crate::Result { + Ok(match s.to_ascii_uppercase().as_ref() { + "US-ASCII" => Us_Ascii, + "ISO-8859-1" => Iso_8859_1, + "ISO-8859-2" => Iso_8859_2, + "ISO-8859-3" => Iso_8859_3, + "ISO-8859-4" => Iso_8859_4, + "ISO-8859-5" => Iso_8859_5, + "ISO-8859-6" => Iso_8859_6, + "ISO-8859-7" => Iso_8859_7, + "ISO-8859-8" => Iso_8859_8, + "ISO-8859-9" => Iso_8859_9, + "ISO-8859-10" => Iso_8859_10, + "SHIFT-JIS" => Shift_Jis, + "EUC-JP" => Euc_Jp, + "ISO-2022-KR" => Iso_2022_Kr, + "EUC-KR" => Euc_Kr, + "ISO-2022-JP" => Iso_2022_Jp, + "ISO-2022-JP-2" => Iso_2022_Jp_2, + "ISO-8859-6-E" => Iso_8859_6_E, + "ISO-8859-6-I" => Iso_8859_6_I, + "ISO-8859-8-E" => Iso_8859_8_E, + "ISO-8859-8-I" => Iso_8859_8_I, + "GB2312" => Gb2312, + "big5" => Big5, + "KOI8-R" => Koi8_R, + s => Ext(s.to_owned()), + }) + } +} + +#[test] +fn test_parse() { + assert_eq!(Us_Ascii, "us-ascii".parse().unwrap()); + assert_eq!(Us_Ascii, "US-Ascii".parse().unwrap()); + assert_eq!(Us_Ascii, "US-ASCII".parse().unwrap()); + assert_eq!(Shift_Jis, "Shift-JIS".parse().unwrap()); + assert_eq!(Ext("ABCD".to_owned()), "abcd".parse().unwrap()); +} + +#[test] +fn test_display() { + assert_eq!("US-ASCII", format!("{}", Us_Ascii)); + assert_eq!("ABCD", format!("{}", Ext("ABCD".to_owned()))); +} diff --git a/actix-http/src/header/shared/encoding.rs b/actix-http/src/header/shared/encoding.rs new file mode 100644 index 000000000..af7404828 --- /dev/null +++ b/actix-http/src/header/shared/encoding.rs @@ -0,0 +1,58 @@ +use std::{fmt, str}; + +pub use self::Encoding::{ + Brotli, Chunked, Compress, Deflate, EncodingExt, Gzip, Identity, Trailers, +}; + +/// A value to represent an encoding used in `Transfer-Encoding` +/// or `Accept-Encoding` header. +#[derive(Clone, PartialEq, Debug)] +pub enum Encoding { + /// The `chunked` encoding. + Chunked, + /// The `br` encoding. + Brotli, + /// The `gzip` encoding. + Gzip, + /// The `deflate` encoding. + Deflate, + /// The `compress` encoding. + Compress, + /// The `identity` encoding. + Identity, + /// The `trailers` encoding. + Trailers, + /// Some other encoding that is less common, can be any String. + EncodingExt(String), +} + +impl fmt::Display for Encoding { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str(match *self { + Chunked => "chunked", + Brotli => "br", + Gzip => "gzip", + Deflate => "deflate", + Compress => "compress", + Identity => "identity", + Trailers => "trailers", + EncodingExt(ref s) => s.as_ref(), + }) + } +} + +impl str::FromStr for Encoding { + type Err = crate::error::ParseError; + fn from_str(s: &str) -> Result { + match s { + "chunked" => Ok(Chunked), + "br" => Ok(Brotli), + "deflate" => Ok(Deflate), + "gzip" => Ok(Gzip), + "compress" => Ok(Compress), + "identity" => Ok(Identity), + "trailers" => Ok(Trailers), + _ => Ok(EncodingExt(s.to_owned())), + } + } +} diff --git a/actix-http/src/header/shared/entity.rs b/actix-http/src/header/shared/entity.rs new file mode 100644 index 000000000..da02dc193 --- /dev/null +++ b/actix-http/src/header/shared/entity.rs @@ -0,0 +1,265 @@ +use std::fmt::{self, Display, Write}; +use std::str::FromStr; + +use crate::header::{HeaderValue, IntoHeaderValue, InvalidHeaderValueBytes, Writer}; + +/// check that each char in the slice is either: +/// 1. `%x21`, or +/// 2. in the range `%x23` to `%x7E`, or +/// 3. above `%x80` +fn check_slice_validity(slice: &str) -> bool { + slice + .bytes() + .all(|c| c == b'\x21' || (c >= b'\x23' && c <= b'\x7e') | (c >= b'\x80')) +} + +/// An entity tag, defined in [RFC7232](https://tools.ietf.org/html/rfc7232#section-2.3) +/// +/// An entity tag consists of a string enclosed by two literal double quotes. +/// Preceding the first double quote is an optional weakness indicator, +/// which always looks like `W/`. Examples for valid tags are `"xyzzy"` and +/// `W/"xyzzy"`. +/// +/// # ABNF +/// +/// ```text +/// entity-tag = [ weak ] opaque-tag +/// weak = %x57.2F ; "W/", case-sensitive +/// opaque-tag = DQUOTE *etagc DQUOTE +/// etagc = %x21 / %x23-7E / obs-text +/// ; VCHAR except double quotes, plus obs-text +/// ``` +/// +/// # Comparison +/// To check if two entity tags are equivalent in an application always use the +/// `strong_eq` or `weak_eq` methods based on the context of the Tag. Only use +/// `==` to check if two tags are identical. +/// +/// The example below shows the results for a set of entity-tag pairs and +/// both the weak and strong comparison function results: +/// +/// | `ETag 1`| `ETag 2`| Strong Comparison | Weak Comparison | +/// |---------|---------|-------------------|-----------------| +/// | `W/"1"` | `W/"1"` | no match | match | +/// | `W/"1"` | `W/"2"` | no match | no match | +/// | `W/"1"` | `"1"` | no match | match | +/// | `"1"` | `"1"` | match | match | +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct EntityTag { + /// Weakness indicator for the tag + pub weak: bool, + /// The opaque string in between the DQUOTEs + tag: String, +} + +impl EntityTag { + /// Constructs a new EntityTag. + /// # Panics + /// If the tag contains invalid characters. + pub fn new(weak: bool, tag: String) -> EntityTag { + assert!(check_slice_validity(&tag), "Invalid tag: {:?}", tag); + EntityTag { weak, tag } + } + + /// Constructs a new weak EntityTag. + /// # Panics + /// If the tag contains invalid characters. + pub fn weak(tag: String) -> EntityTag { + EntityTag::new(true, tag) + } + + /// Constructs a new strong EntityTag. + /// # Panics + /// If the tag contains invalid characters. + pub fn strong(tag: String) -> EntityTag { + EntityTag::new(false, tag) + } + + /// Get the tag. + pub fn tag(&self) -> &str { + self.tag.as_ref() + } + + /// Set the tag. + /// # Panics + /// If the tag contains invalid characters. + pub fn set_tag(&mut self, tag: String) { + assert!(check_slice_validity(&tag), "Invalid tag: {:?}", tag); + self.tag = tag + } + + /// For strong comparison two entity-tags are equivalent if both are not + /// weak and their opaque-tags match character-by-character. + pub fn strong_eq(&self, other: &EntityTag) -> bool { + !self.weak && !other.weak && self.tag == other.tag + } + + /// For weak comparison two entity-tags are equivalent if their + /// opaque-tags match character-by-character, regardless of either or + /// both being tagged as "weak". + pub fn weak_eq(&self, other: &EntityTag) -> bool { + self.tag == other.tag + } + + /// The inverse of `EntityTag.strong_eq()`. + pub fn strong_ne(&self, other: &EntityTag) -> bool { + !self.strong_eq(other) + } + + /// The inverse of `EntityTag.weak_eq()`. + pub fn weak_ne(&self, other: &EntityTag) -> bool { + !self.weak_eq(other) + } +} + +impl Display for EntityTag { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + if self.weak { + write!(f, "W/\"{}\"", self.tag) + } else { + write!(f, "\"{}\"", self.tag) + } + } +} + +impl FromStr for EntityTag { + type Err = crate::error::ParseError; + + fn from_str(s: &str) -> Result { + let length: usize = s.len(); + let slice = &s[..]; + // Early exits if it doesn't terminate in a DQUOTE. + if !slice.ends_with('"') || slice.len() < 2 { + return Err(crate::error::ParseError::Header); + } + // The etag is weak if its first char is not a DQUOTE. + if slice.len() >= 2 + && slice.starts_with('"') + && check_slice_validity(&slice[1..length - 1]) + { + // No need to check if the last char is a DQUOTE, + // we already did that above. + return Ok(EntityTag { + weak: false, + tag: slice[1..length - 1].to_owned(), + }); + } else if slice.len() >= 4 + && slice.starts_with("W/\"") + && check_slice_validity(&slice[3..length - 1]) + { + return Ok(EntityTag { + weak: true, + tag: slice[3..length - 1].to_owned(), + }); + } + Err(crate::error::ParseError::Header) + } +} + +impl IntoHeaderValue for EntityTag { + type Error = InvalidHeaderValueBytes; + + fn try_into(self) -> Result { + let mut wrt = Writer::new(); + write!(wrt, "{}", self).unwrap(); + HeaderValue::from_shared(wrt.take()) + } +} + +#[cfg(test)] +mod tests { + use super::EntityTag; + + #[test] + fn test_etag_parse_success() { + // Expected success + assert_eq!( + "\"foobar\"".parse::().unwrap(), + EntityTag::strong("foobar".to_owned()) + ); + assert_eq!( + "\"\"".parse::().unwrap(), + EntityTag::strong("".to_owned()) + ); + assert_eq!( + "W/\"weaktag\"".parse::().unwrap(), + EntityTag::weak("weaktag".to_owned()) + ); + assert_eq!( + "W/\"\x65\x62\"".parse::().unwrap(), + EntityTag::weak("\x65\x62".to_owned()) + ); + assert_eq!( + "W/\"\"".parse::().unwrap(), + EntityTag::weak("".to_owned()) + ); + } + + #[test] + fn test_etag_parse_failures() { + // Expected failures + assert!("no-dquotes".parse::().is_err()); + assert!("w/\"the-first-w-is-case-sensitive\"" + .parse::() + .is_err()); + assert!("".parse::().is_err()); + assert!("\"unmatched-dquotes1".parse::().is_err()); + assert!("unmatched-dquotes2\"".parse::().is_err()); + assert!("matched-\"dquotes\"".parse::().is_err()); + } + + #[test] + fn test_etag_fmt() { + assert_eq!( + format!("{}", EntityTag::strong("foobar".to_owned())), + "\"foobar\"" + ); + assert_eq!(format!("{}", EntityTag::strong("".to_owned())), "\"\""); + assert_eq!( + format!("{}", EntityTag::weak("weak-etag".to_owned())), + "W/\"weak-etag\"" + ); + assert_eq!( + format!("{}", EntityTag::weak("\u{0065}".to_owned())), + "W/\"\x65\"" + ); + assert_eq!(format!("{}", EntityTag::weak("".to_owned())), "W/\"\""); + } + + #[test] + fn test_cmp() { + // | ETag 1 | ETag 2 | Strong Comparison | Weak Comparison | + // |---------|---------|-------------------|-----------------| + // | `W/"1"` | `W/"1"` | no match | match | + // | `W/"1"` | `W/"2"` | no match | no match | + // | `W/"1"` | `"1"` | no match | match | + // | `"1"` | `"1"` | match | match | + let mut etag1 = EntityTag::weak("1".to_owned()); + let mut etag2 = EntityTag::weak("1".to_owned()); + assert!(!etag1.strong_eq(&etag2)); + assert!(etag1.weak_eq(&etag2)); + assert!(etag1.strong_ne(&etag2)); + assert!(!etag1.weak_ne(&etag2)); + + etag1 = EntityTag::weak("1".to_owned()); + etag2 = EntityTag::weak("2".to_owned()); + assert!(!etag1.strong_eq(&etag2)); + assert!(!etag1.weak_eq(&etag2)); + assert!(etag1.strong_ne(&etag2)); + assert!(etag1.weak_ne(&etag2)); + + etag1 = EntityTag::weak("1".to_owned()); + etag2 = EntityTag::strong("1".to_owned()); + assert!(!etag1.strong_eq(&etag2)); + assert!(etag1.weak_eq(&etag2)); + assert!(etag1.strong_ne(&etag2)); + assert!(!etag1.weak_ne(&etag2)); + + etag1 = EntityTag::strong("1".to_owned()); + etag2 = EntityTag::strong("1".to_owned()); + assert!(etag1.strong_eq(&etag2)); + assert!(etag1.weak_eq(&etag2)); + assert!(!etag1.strong_ne(&etag2)); + assert!(!etag1.weak_ne(&etag2)); + } +} diff --git a/actix-http/src/header/shared/httpdate.rs b/actix-http/src/header/shared/httpdate.rs new file mode 100644 index 000000000..350f77bbe --- /dev/null +++ b/actix-http/src/header/shared/httpdate.rs @@ -0,0 +1,118 @@ +use std::fmt::{self, Display}; +use std::io::Write; +use std::str::FromStr; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; + +use bytes::{BufMut, BytesMut}; +use http::header::{HeaderValue, InvalidHeaderValueBytes}; + +use crate::error::ParseError; +use crate::header::IntoHeaderValue; + +/// A timestamp with HTTP formatting and parsing +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] +pub struct HttpDate(time::Tm); + +impl FromStr for HttpDate { + type Err = ParseError; + + fn from_str(s: &str) -> Result { + match time::strptime(s, "%a, %d %b %Y %T %Z") + .or_else(|_| time::strptime(s, "%A, %d-%b-%y %T %Z")) + .or_else(|_| time::strptime(s, "%c")) + { + Ok(t) => Ok(HttpDate(t)), + Err(_) => Err(ParseError::Header), + } + } +} + +impl Display for HttpDate { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fmt::Display::fmt(&self.0.to_utc().rfc822(), f) + } +} + +impl From for HttpDate { + fn from(tm: time::Tm) -> HttpDate { + HttpDate(tm) + } +} + +impl From for HttpDate { + fn from(sys: SystemTime) -> HttpDate { + let tmspec = match sys.duration_since(UNIX_EPOCH) { + Ok(dur) => { + time::Timespec::new(dur.as_secs() as i64, dur.subsec_nanos() as i32) + } + Err(err) => { + let neg = err.duration(); + time::Timespec::new( + -(neg.as_secs() as i64), + -(neg.subsec_nanos() as i32), + ) + } + }; + HttpDate(time::at_utc(tmspec)) + } +} + +impl IntoHeaderValue for HttpDate { + type Error = InvalidHeaderValueBytes; + + fn try_into(self) -> Result { + let mut wrt = BytesMut::with_capacity(29).writer(); + write!(wrt, "{}", self.0.rfc822()).unwrap(); + HeaderValue::from_shared(wrt.get_mut().take().freeze()) + } +} + +impl From for SystemTime { + fn from(date: HttpDate) -> SystemTime { + let spec = date.0.to_timespec(); + if spec.sec >= 0 { + UNIX_EPOCH + Duration::new(spec.sec as u64, spec.nsec as u32) + } else { + UNIX_EPOCH - Duration::new(spec.sec as u64, spec.nsec as u32) + } + } +} + +#[cfg(test)] +mod tests { + use super::HttpDate; + use time::Tm; + + const NOV_07: HttpDate = HttpDate(Tm { + tm_nsec: 0, + tm_sec: 37, + tm_min: 48, + tm_hour: 8, + tm_mday: 7, + tm_mon: 10, + tm_year: 94, + tm_wday: 0, + tm_isdst: 0, + tm_yday: 0, + tm_utcoff: 0, + }); + + #[test] + fn test_date() { + assert_eq!( + "Sun, 07 Nov 1994 08:48:37 GMT".parse::().unwrap(), + NOV_07 + ); + assert_eq!( + "Sunday, 07-Nov-94 08:48:37 GMT" + .parse::() + .unwrap(), + NOV_07 + ); + assert_eq!( + "Sun Nov 7 08:48:37 1994".parse::().unwrap(), + NOV_07 + ); + assert!("this-is-no-date".parse::().is_err()); + } +} diff --git a/actix-http/src/header/shared/mod.rs b/actix-http/src/header/shared/mod.rs new file mode 100644 index 000000000..f2bc91634 --- /dev/null +++ b/actix-http/src/header/shared/mod.rs @@ -0,0 +1,14 @@ +//! Copied for `hyper::header::shared`; + +pub use self::charset::Charset; +pub use self::encoding::Encoding; +pub use self::entity::EntityTag; +pub use self::httpdate::HttpDate; +pub use self::quality_item::{q, qitem, Quality, QualityItem}; +pub use language_tags::LanguageTag; + +mod charset; +mod encoding; +mod entity; +mod httpdate; +mod quality_item; diff --git a/actix-http/src/header/shared/quality_item.rs b/actix-http/src/header/shared/quality_item.rs new file mode 100644 index 000000000..fc3930c5e --- /dev/null +++ b/actix-http/src/header/shared/quality_item.rs @@ -0,0 +1,291 @@ +use std::{cmp, fmt, str}; + +use self::internal::IntoQuality; + +/// Represents a quality used in quality values. +/// +/// Can be created with the `q` function. +/// +/// # Implementation notes +/// +/// The quality value is defined as a number between 0 and 1 with three decimal +/// places. This means there are 1001 possible values. Since floating point +/// numbers are not exact and the smallest floating point data type (`f32`) +/// consumes four bytes, hyper uses an `u16` value to store the +/// quality internally. For performance reasons you may set quality directly to +/// a value between 0 and 1000 e.g. `Quality(532)` matches the quality +/// `q=0.532`. +/// +/// [RFC7231 Section 5.3.1](https://tools.ietf.org/html/rfc7231#section-5.3.1) +/// gives more information on quality values in HTTP header fields. +#[derive(Copy, Clone, Debug, Eq, Ord, PartialEq, PartialOrd)] +pub struct Quality(u16); + +impl Default for Quality { + fn default() -> Quality { + Quality(1000) + } +} + +/// Represents an item with a quality value as defined in +/// [RFC7231](https://tools.ietf.org/html/rfc7231#section-5.3.1). +#[derive(Clone, PartialEq, Debug)] +pub struct QualityItem { + /// The actual contents of the field. + pub item: T, + /// The quality (client or server preference) for the value. + pub quality: Quality, +} + +impl QualityItem { + /// Creates a new `QualityItem` from an item and a quality. + /// The item can be of any type. + /// The quality should be a value in the range [0, 1]. + pub fn new(item: T, quality: Quality) -> QualityItem { + QualityItem { item, quality } + } +} + +impl cmp::PartialOrd for QualityItem { + fn partial_cmp(&self, other: &QualityItem) -> Option { + self.quality.partial_cmp(&other.quality) + } +} + +impl fmt::Display for QualityItem { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fmt::Display::fmt(&self.item, f)?; + match self.quality.0 { + 1000 => Ok(()), + 0 => f.write_str("; q=0"), + x => write!(f, "; q=0.{}", format!("{:03}", x).trim_end_matches('0')), + } + } +} + +impl str::FromStr for QualityItem { + type Err = crate::error::ParseError; + + fn from_str(s: &str) -> Result, crate::error::ParseError> { + if !s.is_ascii() { + return Err(crate::error::ParseError::Header); + } + // Set defaults used if parsing fails. + let mut raw_item = s; + let mut quality = 1f32; + + let parts: Vec<&str> = s.rsplitn(2, ';').map(|x| x.trim()).collect(); + if parts.len() == 2 { + if parts[0].len() < 2 { + return Err(crate::error::ParseError::Header); + } + let start = &parts[0][0..2]; + if start == "q=" || start == "Q=" { + let q_part = &parts[0][2..parts[0].len()]; + if q_part.len() > 5 { + return Err(crate::error::ParseError::Header); + } + match q_part.parse::() { + Ok(q_value) => { + if 0f32 <= q_value && q_value <= 1f32 { + quality = q_value; + raw_item = parts[1]; + } else { + return Err(crate::error::ParseError::Header); + } + } + Err(_) => return Err(crate::error::ParseError::Header), + } + } + } + match raw_item.parse::() { + // we already checked above that the quality is within range + Ok(item) => Ok(QualityItem::new(item, from_f32(quality))), + Err(_) => Err(crate::error::ParseError::Header), + } + } +} + +#[inline] +fn from_f32(f: f32) -> Quality { + // this function is only used internally. A check that `f` is within range + // should be done before calling this method. Just in case, this + // debug_assert should catch if we were forgetful + debug_assert!( + f >= 0f32 && f <= 1f32, + "q value must be between 0.0 and 1.0" + ); + Quality((f * 1000f32) as u16) +} + +/// Convenience function to wrap a value in a `QualityItem` +/// Sets `q` to the default 1.0 +pub fn qitem(item: T) -> QualityItem { + QualityItem::new(item, Default::default()) +} + +/// Convenience function to create a `Quality` from a float or integer. +/// +/// Implemented for `u16` and `f32`. Panics if value is out of range. +pub fn q(val: T) -> Quality { + val.into_quality() +} + +mod internal { + use super::Quality; + + // TryFrom is probably better, but it's not stable. For now, we want to + // keep the functionality of the `q` function, while allowing it to be + // generic over `f32` and `u16`. + // + // `q` would panic before, so keep that behavior. `TryFrom` can be + // introduced later for a non-panicking conversion. + + pub trait IntoQuality: Sealed + Sized { + fn into_quality(self) -> Quality; + } + + impl IntoQuality for f32 { + fn into_quality(self) -> Quality { + assert!( + self >= 0f32 && self <= 1f32, + "float must be between 0.0 and 1.0" + ); + super::from_f32(self) + } + } + + impl IntoQuality for u16 { + fn into_quality(self) -> Quality { + assert!(self <= 1000, "u16 must be between 0 and 1000"); + Quality(self) + } + } + + pub trait Sealed {} + impl Sealed for u16 {} + impl Sealed for f32 {} +} + +#[cfg(test)] +mod tests { + use super::super::encoding::*; + use super::*; + + #[test] + fn test_quality_item_fmt_q_1() { + let x = qitem(Chunked); + assert_eq!(format!("{}", x), "chunked"); + } + #[test] + fn test_quality_item_fmt_q_0001() { + let x = QualityItem::new(Chunked, Quality(1)); + assert_eq!(format!("{}", x), "chunked; q=0.001"); + } + #[test] + fn test_quality_item_fmt_q_05() { + // Custom value + let x = QualityItem { + item: EncodingExt("identity".to_owned()), + quality: Quality(500), + }; + assert_eq!(format!("{}", x), "identity; q=0.5"); + } + + #[test] + fn test_quality_item_fmt_q_0() { + // Custom value + let x = QualityItem { + item: EncodingExt("identity".to_owned()), + quality: Quality(0), + }; + assert_eq!(x.to_string(), "identity; q=0"); + } + + #[test] + fn test_quality_item_from_str1() { + let x: Result, _> = "chunked".parse(); + assert_eq!( + x.unwrap(), + QualityItem { + item: Chunked, + quality: Quality(1000), + } + ); + } + #[test] + fn test_quality_item_from_str2() { + let x: Result, _> = "chunked; q=1".parse(); + assert_eq!( + x.unwrap(), + QualityItem { + item: Chunked, + quality: Quality(1000), + } + ); + } + #[test] + fn test_quality_item_from_str3() { + let x: Result, _> = "gzip; q=0.5".parse(); + assert_eq!( + x.unwrap(), + QualityItem { + item: Gzip, + quality: Quality(500), + } + ); + } + #[test] + fn test_quality_item_from_str4() { + let x: Result, _> = "gzip; q=0.273".parse(); + assert_eq!( + x.unwrap(), + QualityItem { + item: Gzip, + quality: Quality(273), + } + ); + } + #[test] + fn test_quality_item_from_str5() { + let x: Result, _> = "gzip; q=0.2739999".parse(); + assert!(x.is_err()); + } + #[test] + fn test_quality_item_from_str6() { + let x: Result, _> = "gzip; q=2".parse(); + assert!(x.is_err()); + } + #[test] + fn test_quality_item_ordering() { + let x: QualityItem = "gzip; q=0.5".parse().ok().unwrap(); + let y: QualityItem = "gzip; q=0.273".parse().ok().unwrap(); + let comparision_result: bool = x.gt(&y); + assert!(comparision_result) + } + + #[test] + fn test_quality() { + assert_eq!(q(0.5), Quality(500)); + } + + #[test] + #[should_panic] // FIXME - 32-bit msvc unwinding broken + #[cfg_attr(all(target_arch = "x86", target_env = "msvc"), ignore)] + fn test_quality_invalid() { + q(-1.0); + } + + #[test] + #[should_panic] // FIXME - 32-bit msvc unwinding broken + #[cfg_attr(all(target_arch = "x86", target_env = "msvc"), ignore)] + fn test_quality_invalid2() { + q(2.0); + } + + #[test] + fn test_fuzzing_bugs() { + assert!("99999;".parse::>().is_err()); + assert!("\x0d;;;=\u{d6aa}==".parse::>().is_err()) + } +} diff --git a/actix-http/src/helpers.rs b/actix-http/src/helpers.rs new file mode 100644 index 000000000..e4ccd8aef --- /dev/null +++ b/actix-http/src/helpers.rs @@ -0,0 +1,208 @@ +use bytes::{BufMut, BytesMut}; +use http::Version; +use std::{mem, ptr, slice}; + +const DEC_DIGITS_LUT: &[u8] = b"0001020304050607080910111213141516171819\ + 2021222324252627282930313233343536373839\ + 4041424344454647484950515253545556575859\ + 6061626364656667686970717273747576777879\ + 8081828384858687888990919293949596979899"; + +pub(crate) const STATUS_LINE_BUF_SIZE: usize = 13; + +pub(crate) fn write_status_line(version: Version, mut n: u16, bytes: &mut BytesMut) { + let mut buf: [u8; STATUS_LINE_BUF_SIZE] = [ + b'H', b'T', b'T', b'P', b'/', b'1', b'.', b'1', b' ', b' ', b' ', b' ', b' ', + ]; + match version { + Version::HTTP_2 => buf[5] = b'2', + Version::HTTP_10 => buf[7] = b'0', + Version::HTTP_09 => { + buf[5] = b'0'; + buf[7] = b'9'; + } + _ => (), + } + + let mut curr: isize = 12; + let buf_ptr = buf.as_mut_ptr(); + let lut_ptr = DEC_DIGITS_LUT.as_ptr(); + let four = n > 999; + + // decode 2 more chars, if > 2 chars + let d1 = (n % 100) << 1; + n /= 100; + curr -= 2; + unsafe { + ptr::copy_nonoverlapping(lut_ptr.offset(d1 as isize), buf_ptr.offset(curr), 2); + } + + // decode last 1 or 2 chars + if n < 10 { + curr -= 1; + unsafe { + *buf_ptr.offset(curr) = (n as u8) + b'0'; + } + } else { + let d1 = n << 1; + curr -= 2; + unsafe { + ptr::copy_nonoverlapping( + lut_ptr.offset(d1 as isize), + buf_ptr.offset(curr), + 2, + ); + } + } + + bytes.put_slice(&buf); + if four { + bytes.put(b' '); + } +} + +/// NOTE: bytes object has to contain enough space +pub fn write_content_length(mut n: usize, bytes: &mut BytesMut) { + if n < 10 { + let mut buf: [u8; 21] = [ + b'\r', b'\n', b'c', b'o', b'n', b't', b'e', b'n', b't', b'-', b'l', b'e', + b'n', b'g', b't', b'h', b':', b' ', b'0', b'\r', b'\n', + ]; + buf[18] = (n as u8) + b'0'; + bytes.put_slice(&buf); + } else if n < 100 { + let mut buf: [u8; 22] = [ + b'\r', b'\n', b'c', b'o', b'n', b't', b'e', b'n', b't', b'-', b'l', b'e', + b'n', b'g', b't', b'h', b':', b' ', b'0', b'0', b'\r', b'\n', + ]; + let d1 = n << 1; + unsafe { + ptr::copy_nonoverlapping( + DEC_DIGITS_LUT.as_ptr().add(d1), + buf.as_mut_ptr().offset(18), + 2, + ); + } + bytes.put_slice(&buf); + } else if n < 1000 { + let mut buf: [u8; 23] = [ + b'\r', b'\n', b'c', b'o', b'n', b't', b'e', b'n', b't', b'-', b'l', b'e', + b'n', b'g', b't', b'h', b':', b' ', b'0', b'0', b'0', b'\r', b'\n', + ]; + // decode 2 more chars, if > 2 chars + let d1 = (n % 100) << 1; + n /= 100; + unsafe { + ptr::copy_nonoverlapping( + DEC_DIGITS_LUT.as_ptr().add(d1), + buf.as_mut_ptr().offset(19), + 2, + ) + }; + + // decode last 1 + buf[18] = (n as u8) + b'0'; + + bytes.put_slice(&buf); + } else { + bytes.put_slice(b"\r\ncontent-length: "); + convert_usize(n, bytes); + } +} + +pub(crate) fn convert_usize(mut n: usize, bytes: &mut BytesMut) { + let mut curr: isize = 39; + let mut buf: [u8; 41] = unsafe { mem::uninitialized() }; + buf[39] = b'\r'; + buf[40] = b'\n'; + let buf_ptr = buf.as_mut_ptr(); + let lut_ptr = DEC_DIGITS_LUT.as_ptr(); + + // eagerly decode 4 characters at a time + while n >= 10_000 { + let rem = (n % 10_000) as isize; + n /= 10_000; + + let d1 = (rem / 100) << 1; + let d2 = (rem % 100) << 1; + curr -= 4; + unsafe { + ptr::copy_nonoverlapping(lut_ptr.offset(d1), buf_ptr.offset(curr), 2); + ptr::copy_nonoverlapping(lut_ptr.offset(d2), buf_ptr.offset(curr + 2), 2); + } + } + + // if we reach here numbers are <= 9999, so at most 4 chars long + let mut n = n as isize; // possibly reduce 64bit math + + // decode 2 more chars, if > 2 chars + if n >= 100 { + let d1 = (n % 100) << 1; + n /= 100; + curr -= 2; + unsafe { + ptr::copy_nonoverlapping(lut_ptr.offset(d1), buf_ptr.offset(curr), 2); + } + } + + // decode last 1 or 2 chars + if n < 10 { + curr -= 1; + unsafe { + *buf_ptr.offset(curr) = (n as u8) + b'0'; + } + } else { + let d1 = n << 1; + curr -= 2; + unsafe { + ptr::copy_nonoverlapping(lut_ptr.offset(d1), buf_ptr.offset(curr), 2); + } + } + + unsafe { + bytes.extend_from_slice(slice::from_raw_parts( + buf_ptr.offset(curr), + 41 - curr as usize, + )); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_write_content_length() { + let mut bytes = BytesMut::new(); + bytes.reserve(50); + write_content_length(0, &mut bytes); + assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 0\r\n"[..]); + bytes.reserve(50); + write_content_length(9, &mut bytes); + assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 9\r\n"[..]); + bytes.reserve(50); + write_content_length(10, &mut bytes); + assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 10\r\n"[..]); + bytes.reserve(50); + write_content_length(99, &mut bytes); + assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 99\r\n"[..]); + bytes.reserve(50); + write_content_length(100, &mut bytes); + assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 100\r\n"[..]); + bytes.reserve(50); + write_content_length(101, &mut bytes); + assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 101\r\n"[..]); + bytes.reserve(50); + write_content_length(998, &mut bytes); + assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 998\r\n"[..]); + bytes.reserve(50); + write_content_length(1000, &mut bytes); + assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 1000\r\n"[..]); + bytes.reserve(50); + write_content_length(1001, &mut bytes); + assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 1001\r\n"[..]); + bytes.reserve(50); + write_content_length(5909, &mut bytes); + assert_eq!(bytes.take().freeze(), b"\r\ncontent-length: 5909\r\n"[..]); + } +} diff --git a/actix-http/src/httpcodes.rs b/actix-http/src/httpcodes.rs new file mode 100644 index 000000000..5dfeefa9e --- /dev/null +++ b/actix-http/src/httpcodes.rs @@ -0,0 +1,86 @@ +//! Basic http responses +#![allow(non_upper_case_globals)] +use http::StatusCode; + +use crate::response::{Response, ResponseBuilder}; + +macro_rules! STATIC_RESP { + ($name:ident, $status:expr) => { + #[allow(non_snake_case, missing_docs)] + pub fn $name() -> ResponseBuilder { + Response::build($status) + } + }; +} + +impl Response { + STATIC_RESP!(Ok, StatusCode::OK); + STATIC_RESP!(Created, StatusCode::CREATED); + STATIC_RESP!(Accepted, StatusCode::ACCEPTED); + STATIC_RESP!( + NonAuthoritativeInformation, + StatusCode::NON_AUTHORITATIVE_INFORMATION + ); + + STATIC_RESP!(NoContent, StatusCode::NO_CONTENT); + STATIC_RESP!(ResetContent, StatusCode::RESET_CONTENT); + STATIC_RESP!(PartialContent, StatusCode::PARTIAL_CONTENT); + STATIC_RESP!(MultiStatus, StatusCode::MULTI_STATUS); + STATIC_RESP!(AlreadyReported, StatusCode::ALREADY_REPORTED); + + STATIC_RESP!(MultipleChoices, StatusCode::MULTIPLE_CHOICES); + STATIC_RESP!(MovedPermanenty, StatusCode::MOVED_PERMANENTLY); + STATIC_RESP!(MovedPermanently, StatusCode::MOVED_PERMANENTLY); + STATIC_RESP!(Found, StatusCode::FOUND); + STATIC_RESP!(SeeOther, StatusCode::SEE_OTHER); + STATIC_RESP!(NotModified, StatusCode::NOT_MODIFIED); + STATIC_RESP!(UseProxy, StatusCode::USE_PROXY); + STATIC_RESP!(TemporaryRedirect, StatusCode::TEMPORARY_REDIRECT); + STATIC_RESP!(PermanentRedirect, StatusCode::PERMANENT_REDIRECT); + + STATIC_RESP!(BadRequest, StatusCode::BAD_REQUEST); + STATIC_RESP!(NotFound, StatusCode::NOT_FOUND); + STATIC_RESP!(Unauthorized, StatusCode::UNAUTHORIZED); + STATIC_RESP!(PaymentRequired, StatusCode::PAYMENT_REQUIRED); + STATIC_RESP!(Forbidden, StatusCode::FORBIDDEN); + STATIC_RESP!(MethodNotAllowed, StatusCode::METHOD_NOT_ALLOWED); + STATIC_RESP!(NotAcceptable, StatusCode::NOT_ACCEPTABLE); + STATIC_RESP!( + ProxyAuthenticationRequired, + StatusCode::PROXY_AUTHENTICATION_REQUIRED + ); + STATIC_RESP!(RequestTimeout, StatusCode::REQUEST_TIMEOUT); + STATIC_RESP!(Conflict, StatusCode::CONFLICT); + STATIC_RESP!(Gone, StatusCode::GONE); + STATIC_RESP!(LengthRequired, StatusCode::LENGTH_REQUIRED); + STATIC_RESP!(PreconditionFailed, StatusCode::PRECONDITION_FAILED); + STATIC_RESP!(PreconditionRequired, StatusCode::PRECONDITION_REQUIRED); + STATIC_RESP!(PayloadTooLarge, StatusCode::PAYLOAD_TOO_LARGE); + STATIC_RESP!(UriTooLong, StatusCode::URI_TOO_LONG); + STATIC_RESP!(UnsupportedMediaType, StatusCode::UNSUPPORTED_MEDIA_TYPE); + STATIC_RESP!(RangeNotSatisfiable, StatusCode::RANGE_NOT_SATISFIABLE); + STATIC_RESP!(ExpectationFailed, StatusCode::EXPECTATION_FAILED); + + STATIC_RESP!(InternalServerError, StatusCode::INTERNAL_SERVER_ERROR); + STATIC_RESP!(NotImplemented, StatusCode::NOT_IMPLEMENTED); + STATIC_RESP!(BadGateway, StatusCode::BAD_GATEWAY); + STATIC_RESP!(ServiceUnavailable, StatusCode::SERVICE_UNAVAILABLE); + STATIC_RESP!(GatewayTimeout, StatusCode::GATEWAY_TIMEOUT); + STATIC_RESP!(VersionNotSupported, StatusCode::HTTP_VERSION_NOT_SUPPORTED); + STATIC_RESP!(VariantAlsoNegotiates, StatusCode::VARIANT_ALSO_NEGOTIATES); + STATIC_RESP!(InsufficientStorage, StatusCode::INSUFFICIENT_STORAGE); + STATIC_RESP!(LoopDetected, StatusCode::LOOP_DETECTED); +} + +#[cfg(test)] +mod tests { + use crate::body::Body; + use crate::response::Response; + use http::StatusCode; + + #[test] + fn test_build() { + let resp = Response::Ok().body(Body::Empty); + assert_eq!(resp.status(), StatusCode::OK); + } +} diff --git a/actix-http/src/httpmessage.rs b/actix-http/src/httpmessage.rs new file mode 100644 index 000000000..60821d300 --- /dev/null +++ b/actix-http/src/httpmessage.rs @@ -0,0 +1,269 @@ +use std::cell::{Ref, RefMut}; +use std::str; + +use encoding::all::UTF_8; +use encoding::label::encoding_from_whatwg_label; +use encoding::EncodingRef; +use http::{header, HeaderMap}; +use mime::Mime; + +use crate::error::{ContentTypeError, ParseError}; +use crate::extensions::Extensions; +use crate::header::Header; +use crate::payload::Payload; + +#[cfg(feature = "cookies")] +use crate::error::CookieParseError; +#[cfg(feature = "cookies")] +use cookie::Cookie; + +#[cfg(feature = "cookies")] +struct Cookies(Vec>); + +/// Trait that implements general purpose operations on http messages +pub trait HttpMessage: Sized { + /// Type of message payload stream + type Stream; + + /// Read the message headers. + fn headers(&self) -> &HeaderMap; + + /// Message payload stream + fn take_payload(&mut self) -> Payload; + + /// Request's extensions container + fn extensions(&self) -> Ref; + + /// Mutable reference to a the request's extensions container + fn extensions_mut(&self) -> RefMut; + + #[doc(hidden)] + /// Get a header + fn get_header(&self) -> Option + where + Self: Sized, + { + if self.headers().contains_key(H::name()) { + H::parse(self).ok() + } else { + None + } + } + + /// Read the request content type. If request does not contain + /// *Content-Type* header, empty str get returned. + fn content_type(&self) -> &str { + if let Some(content_type) = self.headers().get(header::CONTENT_TYPE) { + if let Ok(content_type) = content_type.to_str() { + return content_type.split(';').next().unwrap().trim(); + } + } + "" + } + + /// Get content type encoding + /// + /// UTF-8 is used by default, If request charset is not set. + fn encoding(&self) -> Result { + if let Some(mime_type) = self.mime_type()? { + if let Some(charset) = mime_type.get_param("charset") { + if let Some(enc) = encoding_from_whatwg_label(charset.as_str()) { + Ok(enc) + } else { + Err(ContentTypeError::UnknownEncoding) + } + } else { + Ok(UTF_8) + } + } else { + Ok(UTF_8) + } + } + + /// Convert the request content type to a known mime type. + fn mime_type(&self) -> Result, ContentTypeError> { + if let Some(content_type) = self.headers().get(header::CONTENT_TYPE) { + if let Ok(content_type) = content_type.to_str() { + return match content_type.parse() { + Ok(mt) => Ok(Some(mt)), + Err(_) => Err(ContentTypeError::ParseError), + }; + } else { + return Err(ContentTypeError::ParseError); + } + } + Ok(None) + } + + /// Check if request has chunked transfer encoding + fn chunked(&self) -> Result { + if let Some(encodings) = self.headers().get(header::TRANSFER_ENCODING) { + if let Ok(s) = encodings.to_str() { + Ok(s.to_lowercase().contains("chunked")) + } else { + Err(ParseError::Header) + } + } else { + Ok(false) + } + } + + /// Load request cookies. + #[inline] + #[cfg(feature = "cookies")] + fn cookies(&self) -> Result>>, CookieParseError> { + if self.extensions().get::().is_none() { + let mut cookies = Vec::new(); + for hdr in self.headers().get_all(header::COOKIE) { + let s = + str::from_utf8(hdr.as_bytes()).map_err(CookieParseError::from)?; + for cookie_str in s.split(';').map(|s| s.trim()) { + if !cookie_str.is_empty() { + cookies.push(Cookie::parse_encoded(cookie_str)?.into_owned()); + } + } + } + self.extensions_mut().insert(Cookies(cookies)); + } + Ok(Ref::map(self.extensions(), |ext| { + &ext.get::().unwrap().0 + })) + } + + /// Return request cookie. + #[cfg(feature = "cookies")] + fn cookie(&self, name: &str) -> Option> { + if let Ok(cookies) = self.cookies() { + for cookie in cookies.iter() { + if cookie.name() == name { + return Some(cookie.to_owned()); + } + } + } + None + } +} + +impl<'a, T> HttpMessage for &'a mut T +where + T: HttpMessage, +{ + type Stream = T::Stream; + + fn headers(&self) -> &HeaderMap { + (**self).headers() + } + + /// Message payload stream + fn take_payload(&mut self) -> Payload { + (**self).take_payload() + } + + /// Request's extensions container + fn extensions(&self) -> Ref { + (**self).extensions() + } + + /// Mutable reference to a the request's extensions container + fn extensions_mut(&self) -> RefMut { + (**self).extensions_mut() + } +} + +#[cfg(test)] +mod tests { + use bytes::Bytes; + use encoding::all::ISO_8859_2; + use encoding::Encoding; + use mime; + + use super::*; + use crate::test::TestRequest; + + #[test] + fn test_content_type() { + let req = TestRequest::with_header("content-type", "text/plain").finish(); + assert_eq!(req.content_type(), "text/plain"); + let req = + TestRequest::with_header("content-type", "application/json; charset=utf=8") + .finish(); + assert_eq!(req.content_type(), "application/json"); + let req = TestRequest::default().finish(); + assert_eq!(req.content_type(), ""); + } + + #[test] + fn test_mime_type() { + let req = TestRequest::with_header("content-type", "application/json").finish(); + assert_eq!(req.mime_type().unwrap(), Some(mime::APPLICATION_JSON)); + let req = TestRequest::default().finish(); + assert_eq!(req.mime_type().unwrap(), None); + let req = + TestRequest::with_header("content-type", "application/json; charset=utf-8") + .finish(); + let mt = req.mime_type().unwrap().unwrap(); + assert_eq!(mt.get_param(mime::CHARSET), Some(mime::UTF_8)); + assert_eq!(mt.type_(), mime::APPLICATION); + assert_eq!(mt.subtype(), mime::JSON); + } + + #[test] + fn test_mime_type_error() { + let req = TestRequest::with_header( + "content-type", + "applicationadfadsfasdflknadsfklnadsfjson", + ) + .finish(); + assert_eq!(Err(ContentTypeError::ParseError), req.mime_type()); + } + + #[test] + fn test_encoding() { + let req = TestRequest::default().finish(); + assert_eq!(UTF_8.name(), req.encoding().unwrap().name()); + + let req = TestRequest::with_header("content-type", "application/json").finish(); + assert_eq!(UTF_8.name(), req.encoding().unwrap().name()); + + let req = TestRequest::with_header( + "content-type", + "application/json; charset=ISO-8859-2", + ) + .finish(); + assert_eq!(ISO_8859_2.name(), req.encoding().unwrap().name()); + } + + #[test] + fn test_encoding_error() { + let req = TestRequest::with_header("content-type", "applicatjson").finish(); + assert_eq!(Some(ContentTypeError::ParseError), req.encoding().err()); + + let req = TestRequest::with_header( + "content-type", + "application/json; charset=kkkttktk", + ) + .finish(); + assert_eq!( + Some(ContentTypeError::UnknownEncoding), + req.encoding().err() + ); + } + + #[test] + fn test_chunked() { + let req = TestRequest::default().finish(); + assert!(!req.chunked().unwrap()); + + let req = + TestRequest::with_header(header::TRANSFER_ENCODING, "chunked").finish(); + assert!(req.chunked().unwrap()); + + let req = TestRequest::default() + .header( + header::TRANSFER_ENCODING, + Bytes::from_static(b"some va\xadscc\xacas0xsdasdlue"), + ) + .finish(); + assert!(req.chunked().is_err()); + } +} diff --git a/actix-http/src/lib.rs b/actix-http/src/lib.rs new file mode 100644 index 000000000..b41ce7ae8 --- /dev/null +++ b/actix-http/src/lib.rs @@ -0,0 +1,65 @@ +//! Basic http primitives for actix-net framework. +#![allow( + clippy::type_complexity, + clippy::new_without_default, + clippy::new_without_default_derive +)] + +#[macro_use] +extern crate log; + +pub mod body; +mod builder; +pub mod client; +mod config; +mod extensions; +mod header; +mod helpers; +mod httpcodes; +pub mod httpmessage; +mod message; +mod payload; +mod request; +mod response; +mod service; + +pub mod error; +pub mod h1; +pub mod h2; +pub mod test; +pub mod ws; + +pub use self::builder::HttpServiceBuilder; +pub use self::config::{KeepAlive, ServiceConfig}; +pub use self::error::{Error, ResponseError, Result}; +pub use self::extensions::Extensions; +pub use self::httpmessage::HttpMessage; +pub use self::message::{Head, Message, RequestHead, ResponseHead}; +pub use self::payload::{Payload, PayloadStream}; +pub use self::request::Request; +pub use self::response::{Response, ResponseBuilder}; +pub use self::service::{HttpService, SendError, SendResponse}; + +pub mod http { + //! Various HTTP related types + + // re-exports + pub use http::header::{HeaderName, HeaderValue}; + pub use http::{Method, StatusCode, Version}; + + #[doc(hidden)] + pub use http::{uri, Error, HeaderMap, HttpTryFrom, Uri}; + + #[doc(hidden)] + pub use http::uri::PathAndQuery; + + #[cfg(feature = "cookies")] + pub use cookie::{Cookie, CookieBuilder}; + + /// Various http headers + pub mod header { + pub use crate::header::*; + } + pub use crate::header::ContentEncoding; + pub use crate::message::ConnectionType; +} diff --git a/actix-http/src/message.rs b/actix-http/src/message.rs new file mode 100644 index 000000000..4e46093f6 --- /dev/null +++ b/actix-http/src/message.rs @@ -0,0 +1,315 @@ +use std::cell::{Ref, RefCell, RefMut}; +use std::collections::VecDeque; +use std::rc::Rc; + +use crate::extensions::Extensions; +use crate::http::{header, HeaderMap, Method, StatusCode, Uri, Version}; + +/// Represents various types of connection +#[derive(Copy, Clone, PartialEq, Debug)] +pub enum ConnectionType { + /// Close connection after response + Close, + /// Keep connection alive after response + KeepAlive, + /// Connection is upgraded to different type + Upgrade, +} + +#[doc(hidden)] +pub trait Head: Default + 'static { + fn clear(&mut self); + + /// Read the message headers. + fn headers(&self) -> &HeaderMap; + + /// Mutable reference to the message headers. + fn headers_mut(&mut self) -> &mut HeaderMap; + + /// Connection type + fn connection_type(&self) -> ConnectionType; + + /// Set connection type of the message + fn set_connection_type(&mut self, ctype: ConnectionType); + + fn upgrade(&self) -> bool { + if let Some(hdr) = self.headers().get(header::CONNECTION) { + if let Ok(s) = hdr.to_str() { + s.to_ascii_lowercase().contains("upgrade") + } else { + false + } + } else { + false + } + } + + /// Check if keep-alive is enabled + fn keep_alive(&self) -> bool { + self.connection_type() == ConnectionType::KeepAlive + } + + fn pool() -> &'static MessagePool; +} + +#[derive(Debug)] +pub struct RequestHead { + pub uri: Uri, + pub method: Method, + pub version: Version, + pub headers: HeaderMap, + pub ctype: Option, + pub no_chunking: bool, + pub extensions: RefCell, +} + +impl Default for RequestHead { + fn default() -> RequestHead { + RequestHead { + uri: Uri::default(), + method: Method::default(), + version: Version::HTTP_11, + headers: HeaderMap::with_capacity(16), + ctype: None, + no_chunking: false, + extensions: RefCell::new(Extensions::new()), + } + } +} + +impl Head for RequestHead { + fn clear(&mut self) { + self.ctype = None; + self.headers.clear(); + self.extensions.borrow_mut().clear(); + } + + fn headers(&self) -> &HeaderMap { + &self.headers + } + + fn headers_mut(&mut self) -> &mut HeaderMap { + &mut self.headers + } + + fn set_connection_type(&mut self, ctype: ConnectionType) { + self.ctype = Some(ctype) + } + + fn connection_type(&self) -> ConnectionType { + if let Some(ct) = self.ctype { + ct + } else if self.version < Version::HTTP_11 { + ConnectionType::Close + } else { + ConnectionType::KeepAlive + } + } + + fn upgrade(&self) -> bool { + if let Some(hdr) = self.headers().get(header::CONNECTION) { + if let Ok(s) = hdr.to_str() { + s.to_ascii_lowercase().contains("upgrade") + } else { + false + } + } else { + false + } + } + + fn pool() -> &'static MessagePool { + REQUEST_POOL.with(|p| *p) + } +} + +impl RequestHead { + /// Message extensions + #[inline] + pub fn extensions(&self) -> Ref { + self.extensions.borrow() + } + + /// Mutable reference to a the message's extensions + #[inline] + pub fn extensions_mut(&self) -> RefMut { + self.extensions.borrow_mut() + } +} + +#[derive(Debug)] +pub struct ResponseHead { + pub version: Version, + pub status: StatusCode, + pub headers: HeaderMap, + pub reason: Option<&'static str>, + pub no_chunking: bool, + pub(crate) ctype: Option, + pub(crate) extensions: RefCell, +} + +impl Default for ResponseHead { + fn default() -> ResponseHead { + ResponseHead { + version: Version::default(), + status: StatusCode::OK, + headers: HeaderMap::with_capacity(16), + reason: None, + no_chunking: false, + ctype: None, + extensions: RefCell::new(Extensions::new()), + } + } +} + +impl ResponseHead { + /// Message extensions + #[inline] + pub fn extensions(&self) -> Ref { + self.extensions.borrow() + } + + /// Mutable reference to a the message's extensions + #[inline] + pub fn extensions_mut(&self) -> RefMut { + self.extensions.borrow_mut() + } +} + +impl Head for ResponseHead { + fn clear(&mut self) { + self.ctype = None; + self.reason = None; + self.no_chunking = false; + self.headers.clear(); + } + + fn headers(&self) -> &HeaderMap { + &self.headers + } + + fn headers_mut(&mut self) -> &mut HeaderMap { + &mut self.headers + } + + fn set_connection_type(&mut self, ctype: ConnectionType) { + self.ctype = Some(ctype) + } + + fn connection_type(&self) -> ConnectionType { + if let Some(ct) = self.ctype { + ct + } else if self.version < Version::HTTP_11 { + ConnectionType::Close + } else { + ConnectionType::KeepAlive + } + } + + fn upgrade(&self) -> bool { + self.connection_type() == ConnectionType::Upgrade + } + + fn pool() -> &'static MessagePool { + RESPONSE_POOL.with(|p| *p) + } +} + +impl ResponseHead { + /// Get custom reason for the response + #[inline] + pub fn reason(&self) -> &str { + if let Some(reason) = self.reason { + reason + } else { + self.status + .canonical_reason() + .unwrap_or("") + } + } +} + +pub struct Message { + head: Rc, + pool: &'static MessagePool, +} + +impl Message { + /// Get new message from the pool of objects + pub fn new() -> Self { + T::pool().get_message() + } +} + +impl Clone for Message { + fn clone(&self) -> Self { + Message { + head: self.head.clone(), + pool: self.pool, + } + } +} + +impl std::ops::Deref for Message { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.head.as_ref() + } +} + +impl std::ops::DerefMut for Message { + fn deref_mut(&mut self) -> &mut Self::Target { + Rc::get_mut(&mut self.head).expect("Multiple copies exist") + } +} + +impl Drop for Message { + fn drop(&mut self) { + if Rc::strong_count(&self.head) == 1 { + self.pool.release(self.head.clone()); + } + } +} + +#[doc(hidden)] +/// Request's objects pool +pub struct MessagePool(RefCell>>); + +thread_local!(static REQUEST_POOL: &'static MessagePool = MessagePool::::create()); +thread_local!(static RESPONSE_POOL: &'static MessagePool = MessagePool::::create()); + +impl MessagePool { + fn create() -> &'static MessagePool { + let pool = MessagePool(RefCell::new(VecDeque::with_capacity(128))); + Box::leak(Box::new(pool)) + } + + /// Get message from the pool + #[inline] + fn get_message(&'static self) -> Message { + if let Some(mut msg) = self.0.borrow_mut().pop_front() { + if let Some(r) = Rc::get_mut(&mut msg) { + r.clear(); + } + Message { + head: msg, + pool: self, + } + } else { + Message { + head: Rc::new(T::default()), + pool: self, + } + } + } + + #[inline] + /// Release request instance + fn release(&self, msg: Rc) { + let v = &mut self.0.borrow_mut(); + if v.len() < 128 { + v.push_front(msg); + } + } +} diff --git a/actix-http/src/payload.rs b/actix-http/src/payload.rs new file mode 100644 index 000000000..91e6b5c95 --- /dev/null +++ b/actix-http/src/payload.rs @@ -0,0 +1,64 @@ +use bytes::Bytes; +use futures::{Async, Poll, Stream}; +use h2::RecvStream; + +use crate::error::PayloadError; + +/// Type represent boxed payload +pub type PayloadStream = Box>; + +/// Type represent streaming payload +pub enum Payload { + None, + H1(crate::h1::Payload), + H2(crate::h2::Payload), + Stream(S), +} + +impl From for Payload { + fn from(v: crate::h1::Payload) -> Self { + Payload::H1(v) + } +} + +impl From for Payload { + fn from(v: crate::h2::Payload) -> Self { + Payload::H2(v) + } +} + +impl From for Payload { + fn from(v: RecvStream) -> Self { + Payload::H2(crate::h2::Payload::new(v)) + } +} + +impl From for Payload { + fn from(pl: PayloadStream) -> Self { + Payload::Stream(pl) + } +} + +impl Payload { + /// Takes current payload and replaces it with `None` value + pub fn take(&mut self) -> Payload { + std::mem::replace(self, Payload::None) + } +} + +impl Stream for Payload +where + S: Stream, +{ + type Item = Bytes; + type Error = PayloadError; + + fn poll(&mut self) -> Poll, Self::Error> { + match self { + Payload::None => Ok(Async::Ready(None)), + Payload::H1(ref mut pl) => pl.poll(), + Payload::H2(ref mut pl) => pl.poll(), + Payload::Stream(ref mut pl) => pl.poll(), + } + } +} diff --git a/actix-http/src/request.rs b/actix-http/src/request.rs new file mode 100644 index 000000000..a645c7aeb --- /dev/null +++ b/actix-http/src/request.rs @@ -0,0 +1,169 @@ +use std::cell::{Ref, RefMut}; +use std::fmt; + +use http::{header, HeaderMap, Method, Uri, Version}; + +use crate::extensions::Extensions; +use crate::httpmessage::HttpMessage; +use crate::message::{Message, RequestHead}; +use crate::payload::{Payload, PayloadStream}; + +/// Request +pub struct Request

{ + pub(crate) payload: Payload

, + pub(crate) head: Message, +} + +impl

HttpMessage for Request

{ + type Stream = P; + + #[inline] + fn headers(&self) -> &HeaderMap { + &self.head().headers + } + + /// Request extensions + #[inline] + fn extensions(&self) -> Ref { + self.head.extensions() + } + + /// Mutable reference to a the request's extensions + #[inline] + fn extensions_mut(&self) -> RefMut { + self.head.extensions_mut() + } + + fn take_payload(&mut self) -> Payload

{ + std::mem::replace(&mut self.payload, Payload::None) + } +} + +impl From> for Request { + fn from(head: Message) -> Self { + Request { + head, + payload: Payload::None, + } + } +} + +impl Request { + /// Create new Request instance + pub fn new() -> Request { + Request { + head: Message::new(), + payload: Payload::None, + } + } +} + +impl

Request

{ + /// Create new Request instance + pub fn with_payload(payload: Payload

) -> Request

{ + Request { + payload, + head: Message::new(), + } + } + + /// Create new Request instance + pub fn replace_payload(self, payload: Payload) -> (Request, Payload

) { + let pl = self.payload; + ( + Request { + payload, + head: self.head, + }, + pl, + ) + } + + /// Get request's payload + pub fn take_payload(&mut self) -> Payload

{ + std::mem::replace(&mut self.payload, Payload::None) + } + + /// Split request into request head and payload + pub fn into_parts(self) -> (Message, Payload

) { + (self.head, self.payload) + } + + #[inline] + /// Http message part of the request + pub fn head(&self) -> &RequestHead { + &*self.head + } + + #[inline] + #[doc(hidden)] + /// Mutable reference to a http message part of the request + pub fn head_mut(&mut self) -> &mut RequestHead { + &mut *self.head + } + + /// Mutable reference to the message's headers. + pub fn headers_mut(&mut self) -> &mut HeaderMap { + &mut self.head_mut().headers + } + + /// Request's uri. + #[inline] + pub fn uri(&self) -> &Uri { + &self.head().uri + } + + /// Mutable reference to the request's uri. + #[inline] + pub fn uri_mut(&mut self) -> &mut Uri { + &mut self.head_mut().uri + } + + /// Read the Request method. + #[inline] + pub fn method(&self) -> &Method { + &self.head().method + } + + /// Read the Request Version. + #[inline] + pub fn version(&self) -> Version { + self.head().version + } + + /// The target path of this Request. + #[inline] + pub fn path(&self) -> &str { + self.head().uri.path() + } + + /// Check if request requires connection upgrade + pub fn upgrade(&self) -> bool { + if let Some(conn) = self.head().headers.get(header::CONNECTION) { + if let Ok(s) = conn.to_str() { + return s.to_lowercase().contains("upgrade"); + } + } + self.head().method == Method::CONNECT + } +} + +impl

fmt::Debug for Request

{ + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + writeln!( + f, + "\nRequest {:?} {}:{}", + self.version(), + self.method(), + self.path() + )?; + if let Some(q) = self.uri().query().as_ref() { + writeln!(f, " query: ?{:?}", q)?; + } + writeln!(f, " headers:")?; + for (key, val) in self.headers().iter() { + writeln!(f, " {:?}: {:?}", key, val)?; + } + Ok(()) + } +} diff --git a/actix-http/src/response.rs b/actix-http/src/response.rs new file mode 100644 index 000000000..31c0010ab --- /dev/null +++ b/actix-http/src/response.rs @@ -0,0 +1,1058 @@ +//! Http response +use std::io::Write; +use std::{fmt, str}; + +use bytes::{BufMut, Bytes, BytesMut}; +#[cfg(feature = "cookies")] +use cookie::{Cookie, CookieJar}; +use futures::future::{ok, FutureResult, IntoFuture}; +use futures::Stream; +use http::header::{self, HeaderName, HeaderValue}; +use http::{Error as HttpError, HeaderMap, HttpTryFrom, StatusCode}; +use serde::Serialize; +use serde_json; + +use crate::body::{Body, BodyStream, MessageBody, ResponseBody}; +use crate::error::Error; +use crate::header::{Header, IntoHeaderValue}; +use crate::message::{ConnectionType, Head, Message, ResponseHead}; + +/// An HTTP Response +pub struct Response { + head: Message, + body: ResponseBody, + error: Option, +} + +impl Response { + /// Create http response builder with specific status. + #[inline] + pub fn build(status: StatusCode) -> ResponseBuilder { + ResponseBuilder::new(status) + } + + /// Create http response builder + #[inline] + pub fn build_from>(source: T) -> ResponseBuilder { + source.into() + } + + /// Constructs a response + #[inline] + pub fn new(status: StatusCode) -> Response { + let mut head: Message = Message::new(); + head.status = status; + + Response { + head, + body: ResponseBody::Body(Body::Empty), + error: None, + } + } + + /// Constructs an error response + #[inline] + pub fn from_error(error: Error) -> Response { + let mut resp = error.as_response_error().error_response(); + resp.error = Some(error); + resp + } + + /// Convert response to response with body + pub fn into_body(self) -> Response { + let b = match self.body { + ResponseBody::Body(b) => b, + ResponseBody::Other(b) => b, + }; + Response { + head: self.head, + error: self.error, + body: ResponseBody::Other(b), + } + } +} + +impl Response { + #[inline] + /// Http message part of the response + pub fn head(&self) -> &ResponseHead { + &*self.head + } + + #[inline] + /// Mutable reference to a http message part of the response + pub fn head_mut(&mut self) -> &mut ResponseHead { + &mut *self.head + } + + /// Constructs a response with body + #[inline] + pub fn with_body(status: StatusCode, body: B) -> Response { + let mut head: Message = Message::new(); + head.status = status; + Response { + head, + body: ResponseBody::Body(body), + error: None, + } + } + + /// The source `error` for this response + #[inline] + pub fn error(&self) -> Option<&Error> { + self.error.as_ref() + } + + /// Get the response status code + #[inline] + pub fn status(&self) -> StatusCode { + self.head.status + } + + /// Set the `StatusCode` for this response + #[inline] + pub fn status_mut(&mut self) -> &mut StatusCode { + &mut self.head.status + } + + /// Get the headers from the response + #[inline] + pub fn headers(&self) -> &HeaderMap { + &self.head.headers + } + + /// Get a mutable reference to the headers + #[inline] + pub fn headers_mut(&mut self) -> &mut HeaderMap { + &mut self.head.headers + } + + /// Get an iterator for the cookies set by this response + #[inline] + #[cfg(feature = "cookies")] + pub fn cookies(&self) -> CookieIter { + CookieIter { + iter: self.head.headers.get_all(header::SET_COOKIE).iter(), + } + } + + /// Add a cookie to this response + #[inline] + #[cfg(feature = "cookies")] + pub fn add_cookie(&mut self, cookie: &Cookie) -> Result<(), HttpError> { + let h = &mut self.head.headers; + HeaderValue::from_str(&cookie.to_string()) + .map(|c| { + h.append(header::SET_COOKIE, c); + }) + .map_err(|e| e.into()) + } + + /// Remove all cookies with the given name from this response. Returns + /// the number of cookies removed. + #[inline] + #[cfg(feature = "cookies")] + pub fn del_cookie(&mut self, name: &str) -> usize { + let h = &mut self.head.headers; + let vals: Vec = h + .get_all(header::SET_COOKIE) + .iter() + .map(|v| v.to_owned()) + .collect(); + h.remove(header::SET_COOKIE); + + let mut count: usize = 0; + for v in vals { + if let Ok(s) = v.to_str() { + if let Ok(c) = Cookie::parse_encoded(s) { + if c.name() == name { + count += 1; + continue; + } + } + } + h.append(header::SET_COOKIE, v); + } + count + } + + /// Connection upgrade status + #[inline] + pub fn upgrade(&self) -> bool { + self.head.upgrade() + } + + /// Keep-alive status for this connection + pub fn keep_alive(&self) -> bool { + self.head.keep_alive() + } + + /// Get body os this response + #[inline] + pub fn body(&self) -> &ResponseBody { + &self.body + } + + /// Set a body + pub(crate) fn set_body(self, body: B2) -> Response { + Response { + head: self.head, + body: ResponseBody::Body(body), + error: None, + } + } + + /// Drop request's body + pub(crate) fn drop_body(self) -> Response<()> { + Response { + head: self.head, + body: ResponseBody::Body(()), + error: None, + } + } + + /// Set a body and return previous body value + pub(crate) fn replace_body(self, body: B2) -> (Response, ResponseBody) { + ( + Response { + head: self.head, + body: ResponseBody::Body(body), + error: self.error, + }, + self.body, + ) + } + + /// Set a body and return previous body value + pub fn map_body(mut self, f: F) -> Response + where + F: FnOnce(&mut ResponseHead, ResponseBody) -> ResponseBody, + { + let body = f(&mut self.head, self.body); + + Response { + head: self.head, + body: body, + error: self.error, + } + } + + /// Extract response body + pub fn take_body(&mut self) -> ResponseBody { + self.body.take_body() + } +} + +impl fmt::Debug for Response { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let res = writeln!( + f, + "\nResponse {:?} {}{}", + self.head.version, + self.head.status, + self.head.reason.unwrap_or(""), + ); + let _ = writeln!(f, " headers:"); + for (key, val) in self.head.headers.iter() { + let _ = writeln!(f, " {:?}: {:?}", key, val); + } + let _ = writeln!(f, " body: {:?}", self.body.length()); + res + } +} + +impl IntoFuture for Response { + type Item = Response; + type Error = Error; + type Future = FutureResult; + + fn into_future(self) -> Self::Future { + ok(self) + } +} + +#[cfg(feature = "cookies")] +pub struct CookieIter<'a> { + iter: header::ValueIter<'a, HeaderValue>, +} + +#[cfg(feature = "cookies")] +impl<'a> Iterator for CookieIter<'a> { + type Item = Cookie<'a>; + + #[inline] + fn next(&mut self) -> Option> { + for v in self.iter.by_ref() { + if let Ok(c) = Cookie::parse_encoded(v.to_str().ok()?) { + return Some(c); + } + } + None + } +} + +/// An HTTP response builder +/// +/// This type can be used to construct an instance of `Response` through a +/// builder-like pattern. +pub struct ResponseBuilder { + head: Option>, + err: Option, + #[cfg(feature = "cookies")] + cookies: Option, +} + +impl ResponseBuilder { + /// Create response builder + pub fn new(status: StatusCode) -> Self { + let mut head: Message = Message::new(); + head.status = status; + + ResponseBuilder { + head: Some(head), + err: None, + #[cfg(feature = "cookies")] + cookies: None, + } + } + + /// Set HTTP status code of this response. + #[inline] + pub fn status(&mut self, status: StatusCode) -> &mut Self { + if let Some(parts) = parts(&mut self.head, &self.err) { + parts.status = status; + } + self + } + + /// Set a header. + /// + /// ```rust,ignore + /// # extern crate actix_web; + /// use actix_web::{http, Request, Response, Result}; + /// + /// fn index(req: HttpRequest) -> Result { + /// Ok(Response::Ok() + /// .set(http::header::IfModifiedSince( + /// "Sun, 07 Nov 1994 08:48:37 GMT".parse()?, + /// )) + /// .finish()) + /// } + /// fn main() {} + /// ``` + #[doc(hidden)] + pub fn set(&mut self, hdr: H) -> &mut Self { + if let Some(parts) = parts(&mut self.head, &self.err) { + match hdr.try_into() { + Ok(value) => { + parts.headers.append(H::name(), value); + } + Err(e) => self.err = Some(e.into()), + } + } + self + } + + /// Append a header to existing headers. + /// + /// ```rust,ignore + /// # extern crate actix_web; + /// use actix_web::{http, Request, Response}; + /// + /// fn index(req: HttpRequest) -> Response { + /// Response::Ok() + /// .header("X-TEST", "value") + /// .header(http::header::CONTENT_TYPE, "application/json") + /// .finish() + /// } + /// fn main() {} + /// ``` + pub fn header(&mut self, key: K, value: V) -> &mut Self + where + HeaderName: HttpTryFrom, + V: IntoHeaderValue, + { + if let Some(parts) = parts(&mut self.head, &self.err) { + match HeaderName::try_from(key) { + Ok(key) => match value.try_into() { + Ok(value) => { + parts.headers.append(key, value); + } + Err(e) => self.err = Some(e.into()), + }, + Err(e) => self.err = Some(e.into()), + }; + } + self + } + + /// Set a header. + /// + /// ```rust,ignore + /// # extern crate actix_web; + /// use actix_web::{http, Request, Response}; + /// + /// fn index(req: HttpRequest) -> Response { + /// Response::Ok() + /// .set_header("X-TEST", "value") + /// .set_header(http::header::CONTENT_TYPE, "application/json") + /// .finish() + /// } + /// fn main() {} + /// ``` + pub fn set_header(&mut self, key: K, value: V) -> &mut Self + where + HeaderName: HttpTryFrom, + V: IntoHeaderValue, + { + if let Some(parts) = parts(&mut self.head, &self.err) { + match HeaderName::try_from(key) { + Ok(key) => match value.try_into() { + Ok(value) => { + parts.headers.insert(key, value); + } + Err(e) => self.err = Some(e.into()), + }, + Err(e) => self.err = Some(e.into()), + }; + } + self + } + + /// Set the custom reason for the response. + #[inline] + pub fn reason(&mut self, reason: &'static str) -> &mut Self { + if let Some(parts) = parts(&mut self.head, &self.err) { + parts.reason = Some(reason); + } + self + } + + /// Set connection type to KeepAlive + #[inline] + pub fn keep_alive(&mut self) -> &mut Self { + if let Some(parts) = parts(&mut self.head, &self.err) { + parts.set_connection_type(ConnectionType::KeepAlive); + } + self + } + + /// Set connection type to Upgrade + #[inline] + pub fn upgrade(&mut self, value: V) -> &mut Self + where + V: IntoHeaderValue, + { + if let Some(parts) = parts(&mut self.head, &self.err) { + parts.set_connection_type(ConnectionType::Upgrade); + } + self.set_header(header::UPGRADE, value) + } + + /// Force close connection, even if it is marked as keep-alive + #[inline] + pub fn force_close(&mut self) -> &mut Self { + if let Some(parts) = parts(&mut self.head, &self.err) { + parts.set_connection_type(ConnectionType::Close); + } + self + } + + /// Disable chunked transfer encoding for HTTP/1.1 streaming responses. + #[inline] + pub fn no_chunking(&mut self) -> &mut Self { + if let Some(parts) = parts(&mut self.head, &self.err) { + parts.no_chunking = true; + } + self + } + + /// Set response content type + #[inline] + pub fn content_type(&mut self, value: V) -> &mut Self + where + HeaderValue: HttpTryFrom, + { + if let Some(parts) = parts(&mut self.head, &self.err) { + match HeaderValue::try_from(value) { + Ok(value) => { + parts.headers.insert(header::CONTENT_TYPE, value); + } + Err(e) => self.err = Some(e.into()), + }; + } + self + } + + /// Set content length + #[inline] + pub fn content_length(&mut self, len: u64) -> &mut Self { + let mut wrt = BytesMut::new().writer(); + let _ = write!(wrt, "{}", len); + self.header(header::CONTENT_LENGTH, wrt.get_mut().take().freeze()) + } + + /// Set a cookie + /// + /// ```rust,ignore + /// # extern crate actix_web; + /// use actix_web::{http, HttpRequest, Response, Result}; + /// + /// fn index(req: HttpRequest) -> Response { + /// Response::Ok() + /// .cookie( + /// http::Cookie::build("name", "value") + /// .domain("www.rust-lang.org") + /// .path("/") + /// .secure(true) + /// .http_only(true) + /// .finish(), + /// ) + /// .finish() + /// } + /// ``` + #[cfg(feature = "cookies")] + pub fn cookie<'c>(&mut self, cookie: Cookie<'c>) -> &mut Self { + if self.cookies.is_none() { + let mut jar = CookieJar::new(); + jar.add(cookie.into_owned()); + self.cookies = Some(jar) + } else { + self.cookies.as_mut().unwrap().add(cookie.into_owned()); + } + self + } + + /// Remove cookie + /// + /// ```rust,ignore + /// # extern crate actix_web; + /// use actix_web::{http, HttpRequest, Response, Result}; + /// + /// fn index(req: &HttpRequest) -> Response { + /// let mut builder = Response::Ok(); + /// + /// if let Some(ref cookie) = req.cookie("name") { + /// builder.del_cookie(cookie); + /// } + /// + /// builder.finish() + /// } + /// ``` + #[cfg(feature = "cookies")] + pub fn del_cookie<'a>(&mut self, cookie: &Cookie<'a>) -> &mut Self { + { + if self.cookies.is_none() { + self.cookies = Some(CookieJar::new()) + } + let jar = self.cookies.as_mut().unwrap(); + let cookie = cookie.clone().into_owned(); + jar.add_original(cookie.clone()); + jar.remove(cookie); + } + self + } + + /// This method calls provided closure with builder reference if value is + /// true. + pub fn if_true(&mut self, value: bool, f: F) -> &mut Self + where + F: FnOnce(&mut ResponseBuilder), + { + if value { + f(self); + } + self + } + + /// This method calls provided closure with builder reference if value is + /// Some. + pub fn if_some(&mut self, value: Option, f: F) -> &mut Self + where + F: FnOnce(T, &mut ResponseBuilder), + { + if let Some(val) = value { + f(val, self); + } + self + } + + /// Set a body and generate `Response`. + /// + /// `ResponseBuilder` can not be used after this call. + pub fn body>(&mut self, body: B) -> Response { + self.message_body(body.into()) + } + + /// Set a body and generate `Response`. + /// + /// `ResponseBuilder` can not be used after this call. + pub fn message_body(&mut self, body: B) -> Response { + if let Some(e) = self.err.take() { + return Response::from(Error::from(e)).into_body(); + } + + #[allow(unused_mut)] + let mut response = self.head.take().expect("cannot reuse response builder"); + + #[cfg(feature = "cookies")] + { + if let Some(ref jar) = self.cookies { + for cookie in jar.delta() { + match HeaderValue::from_str(&cookie.to_string()) { + Ok(val) => { + let _ = response.headers.append(header::SET_COOKIE, val); + } + Err(e) => return Response::from(Error::from(e)).into_body(), + }; + } + } + } + + Response { + head: response, + body: ResponseBody::Body(body), + error: None, + } + } + + #[inline] + /// Set a streaming body and generate `Response`. + /// + /// `ResponseBuilder` can not be used after this call. + pub fn streaming(&mut self, stream: S) -> Response + where + S: Stream + 'static, + E: Into + 'static, + { + self.body(Body::from_message(BodyStream::new(stream))) + } + + /// Set a json body and generate `Response` + /// + /// `ResponseBuilder` can not be used after this call. + pub fn json(&mut self, value: T) -> Response { + self.json2(&value) + } + + /// Set a json body and generate `Response` + /// + /// `ResponseBuilder` can not be used after this call. + pub fn json2(&mut self, value: &T) -> Response { + match serde_json::to_string(value) { + Ok(body) => { + let contains = if let Some(parts) = parts(&mut self.head, &self.err) { + parts.headers.contains_key(header::CONTENT_TYPE) + } else { + true + }; + if !contains { + self.header(header::CONTENT_TYPE, "application/json"); + } + + self.body(Body::from(body)) + } + Err(e) => Error::from(e).into(), + } + } + + #[inline] + /// Set an empty body and generate `Response` + /// + /// `ResponseBuilder` can not be used after this call. + pub fn finish(&mut self) -> Response { + self.body(Body::Empty) + } + + /// This method construct new `ResponseBuilder` + pub fn take(&mut self) -> ResponseBuilder { + ResponseBuilder { + head: self.head.take(), + err: self.err.take(), + #[cfg(feature = "cookies")] + cookies: self.cookies.take(), + } + } +} + +#[inline] +fn parts<'a>( + parts: &'a mut Option>, + err: &Option, +) -> Option<&'a mut Message> { + if err.is_some() { + return None; + } + parts.as_mut() +} + +/// Convert `Response` to a `ResponseBuilder`. Body get dropped. +impl From> for ResponseBuilder { + fn from(res: Response) -> ResponseBuilder { + // If this response has cookies, load them into a jar + #[cfg(feature = "cookies")] + let mut jar: Option = None; + #[cfg(feature = "cookies")] + for c in res.cookies() { + if let Some(ref mut j) = jar { + j.add_original(c.into_owned()); + } else { + let mut j = CookieJar::new(); + j.add_original(c.into_owned()); + jar = Some(j); + } + } + + ResponseBuilder { + head: Some(res.head), + err: None, + #[cfg(feature = "cookies")] + cookies: jar, + } + } +} + +/// Convert `ResponseHead` to a `ResponseBuilder` +impl<'a> From<&'a ResponseHead> for ResponseBuilder { + fn from(head: &'a ResponseHead) -> ResponseBuilder { + // If this response has cookies, load them into a jar + #[cfg(feature = "cookies")] + let mut jar: Option = None; + + #[cfg(feature = "cookies")] + { + let cookies = CookieIter { + iter: head.headers.get_all(header::SET_COOKIE).iter(), + }; + for c in cookies { + if let Some(ref mut j) = jar { + j.add_original(c.into_owned()); + } else { + let mut j = CookieJar::new(); + j.add_original(c.into_owned()); + jar = Some(j); + } + } + } + + let mut msg: Message = Message::new(); + msg.version = head.version; + msg.status = head.status; + msg.reason = head.reason; + msg.headers = head.headers.clone(); + msg.no_chunking = head.no_chunking; + + ResponseBuilder { + head: Some(msg), + err: None, + #[cfg(feature = "cookies")] + cookies: jar, + } + } +} + +impl IntoFuture for ResponseBuilder { + type Item = Response; + type Error = Error; + type Future = FutureResult; + + fn into_future(mut self) -> Self::Future { + ok(self.finish()) + } +} + +/// Helper converters +impl, E: Into> From> for Response { + fn from(res: Result) -> Self { + match res { + Ok(val) => val.into(), + Err(err) => err.into().into(), + } + } +} + +impl From for Response { + fn from(mut builder: ResponseBuilder) -> Self { + builder.finish() + } +} + +impl From<&'static str> for Response { + fn from(val: &'static str) -> Self { + Response::Ok() + .content_type("text/plain; charset=utf-8") + .body(val) + } +} + +impl From<&'static [u8]> for Response { + fn from(val: &'static [u8]) -> Self { + Response::Ok() + .content_type("application/octet-stream") + .body(val) + } +} + +impl From for Response { + fn from(val: String) -> Self { + Response::Ok() + .content_type("text/plain; charset=utf-8") + .body(val) + } +} + +impl<'a> From<&'a String> for Response { + fn from(val: &'a String) -> Self { + Response::Ok() + .content_type("text/plain; charset=utf-8") + .body(val) + } +} + +impl From for Response { + fn from(val: Bytes) -> Self { + Response::Ok() + .content_type("application/octet-stream") + .body(val) + } +} + +impl From for Response { + fn from(val: BytesMut) -> Self { + Response::Ok() + .content_type("application/octet-stream") + .body(val) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::body::Body; + use crate::http::header::{HeaderValue, CONTENT_TYPE, COOKIE}; + + #[test] + fn test_debug() { + let resp = Response::Ok() + .header(COOKIE, HeaderValue::from_static("cookie1=value1; ")) + .header(COOKIE, HeaderValue::from_static("cookie2=value2; ")) + .finish(); + let dbg = format!("{:?}", resp); + assert!(dbg.contains("Response")); + } + + #[test] + #[cfg(feature = "cookies")] + fn test_response_cookies() { + use crate::httpmessage::HttpMessage; + + let req = crate::test::TestRequest::default() + .header(COOKIE, "cookie1=value1") + .header(COOKIE, "cookie2=value2") + .finish(); + let cookies = req.cookies().unwrap(); + + let resp = Response::Ok() + .cookie( + crate::http::Cookie::build("name", "value") + .domain("www.rust-lang.org") + .path("/test") + .http_only(true) + .max_age(time::Duration::days(1)) + .finish(), + ) + .del_cookie(&cookies[0]) + .finish(); + + let mut val: Vec<_> = resp + .headers() + .get_all("Set-Cookie") + .iter() + .map(|v| v.to_str().unwrap().to_owned()) + .collect(); + val.sort(); + assert!(val[0].starts_with("cookie1=; Max-Age=0;")); + assert_eq!( + val[1], + "name=value; HttpOnly; Path=/test; Domain=www.rust-lang.org; Max-Age=86400" + ); + } + + #[test] + #[cfg(feature = "cookies")] + fn test_update_response_cookies() { + let mut r = Response::Ok() + .cookie(crate::http::Cookie::new("original", "val100")) + .finish(); + + r.add_cookie(&crate::http::Cookie::new("cookie2", "val200")) + .unwrap(); + r.add_cookie(&crate::http::Cookie::new("cookie2", "val250")) + .unwrap(); + r.add_cookie(&crate::http::Cookie::new("cookie3", "val300")) + .unwrap(); + + assert_eq!(r.cookies().count(), 4); + r.del_cookie("cookie2"); + + let mut iter = r.cookies(); + let v = iter.next().unwrap(); + assert_eq!((v.name(), v.value()), ("original", "val100")); + let v = iter.next().unwrap(); + assert_eq!((v.name(), v.value()), ("cookie3", "val300")); + } + + #[test] + fn test_basic_builder() { + let resp = Response::Ok().header("X-TEST", "value").finish(); + assert_eq!(resp.status(), StatusCode::OK); + } + + #[test] + fn test_upgrade() { + let resp = Response::build(StatusCode::OK) + .upgrade("websocket") + .finish(); + assert!(resp.upgrade()); + assert_eq!( + resp.headers().get(header::UPGRADE).unwrap(), + HeaderValue::from_static("websocket") + ); + } + + #[test] + fn test_force_close() { + let resp = Response::build(StatusCode::OK).force_close().finish(); + assert!(!resp.keep_alive()) + } + + #[test] + fn test_content_type() { + let resp = Response::build(StatusCode::OK) + .content_type("text/plain") + .body(Body::Empty); + assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "text/plain") + } + + #[test] + fn test_json() { + let resp = Response::build(StatusCode::OK).json(vec!["v1", "v2", "v3"]); + let ct = resp.headers().get(CONTENT_TYPE).unwrap(); + assert_eq!(ct, HeaderValue::from_static("application/json")); + assert_eq!(resp.body().get_ref(), b"[\"v1\",\"v2\",\"v3\"]"); + } + + #[test] + fn test_json_ct() { + let resp = Response::build(StatusCode::OK) + .header(CONTENT_TYPE, "text/json") + .json(vec!["v1", "v2", "v3"]); + let ct = resp.headers().get(CONTENT_TYPE).unwrap(); + assert_eq!(ct, HeaderValue::from_static("text/json")); + assert_eq!(resp.body().get_ref(), b"[\"v1\",\"v2\",\"v3\"]"); + } + + #[test] + fn test_json2() { + let resp = Response::build(StatusCode::OK).json2(&vec!["v1", "v2", "v3"]); + let ct = resp.headers().get(CONTENT_TYPE).unwrap(); + assert_eq!(ct, HeaderValue::from_static("application/json")); + assert_eq!(resp.body().get_ref(), b"[\"v1\",\"v2\",\"v3\"]"); + } + + #[test] + fn test_json2_ct() { + let resp = Response::build(StatusCode::OK) + .header(CONTENT_TYPE, "text/json") + .json2(&vec!["v1", "v2", "v3"]); + let ct = resp.headers().get(CONTENT_TYPE).unwrap(); + assert_eq!(ct, HeaderValue::from_static("text/json")); + assert_eq!(resp.body().get_ref(), b"[\"v1\",\"v2\",\"v3\"]"); + } + + #[test] + fn test_into_response() { + let resp: Response = "test".into(); + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!( + resp.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("text/plain; charset=utf-8") + ); + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!(resp.body().get_ref(), b"test"); + + let resp: Response = b"test".as_ref().into(); + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!( + resp.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("application/octet-stream") + ); + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!(resp.body().get_ref(), b"test"); + + let resp: Response = "test".to_owned().into(); + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!( + resp.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("text/plain; charset=utf-8") + ); + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!(resp.body().get_ref(), b"test"); + + let resp: Response = (&"test".to_owned()).into(); + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!( + resp.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("text/plain; charset=utf-8") + ); + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!(resp.body().get_ref(), b"test"); + + let b = Bytes::from_static(b"test"); + let resp: Response = b.into(); + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!( + resp.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("application/octet-stream") + ); + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!(resp.body().get_ref(), b"test"); + + let b = Bytes::from_static(b"test"); + let resp: Response = b.into(); + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!( + resp.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("application/octet-stream") + ); + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!(resp.body().get_ref(), b"test"); + + let b = BytesMut::from("test"); + let resp: Response = b.into(); + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!( + resp.headers().get(CONTENT_TYPE).unwrap(), + HeaderValue::from_static("application/octet-stream") + ); + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!(resp.body().get_ref(), b"test"); + } + + #[test] + #[cfg(feature = "cookies")] + fn test_into_builder() { + let mut resp: Response = "test".into(); + assert_eq!(resp.status(), StatusCode::OK); + + resp.add_cookie(&crate::http::Cookie::new("cookie1", "val100")) + .unwrap(); + + let mut builder: ResponseBuilder = resp.into(); + let resp = builder.status(StatusCode::BAD_REQUEST).finish(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + + let cookie = resp.cookies().next().unwrap(); + assert_eq!((cookie.name(), cookie.value()), ("cookie1", "val100")); + } +} diff --git a/actix-http/src/service/mod.rs b/actix-http/src/service/mod.rs new file mode 100644 index 000000000..25e95bf60 --- /dev/null +++ b/actix-http/src/service/mod.rs @@ -0,0 +1,5 @@ +mod senderror; +mod service; + +pub use self::senderror::{SendError, SendResponse}; +pub use self::service::HttpService; diff --git a/actix-http/src/service/senderror.rs b/actix-http/src/service/senderror.rs new file mode 100644 index 000000000..44d362593 --- /dev/null +++ b/actix-http/src/service/senderror.rs @@ -0,0 +1,241 @@ +use std::marker::PhantomData; + +use actix_codec::{AsyncRead, AsyncWrite, Framed}; +use actix_service::{NewService, Service}; +use futures::future::{ok, Either, FutureResult}; +use futures::{Async, Future, Poll, Sink}; + +use crate::body::{BodyLength, MessageBody, ResponseBody}; +use crate::error::{Error, ResponseError}; +use crate::h1::{Codec, Message}; +use crate::response::Response; + +pub struct SendError(PhantomData<(T, R, E)>); + +impl Default for SendError +where + T: AsyncRead + AsyncWrite, + E: ResponseError, +{ + fn default() -> Self { + SendError(PhantomData) + } +} + +impl NewService for SendError +where + T: AsyncRead + AsyncWrite, + E: ResponseError, +{ + type Request = Result)>; + type Response = R; + type Error = (E, Framed); + type InitError = (); + type Service = SendError; + type Future = FutureResult; + + fn new_service(&self, _: &()) -> Self::Future { + ok(SendError(PhantomData)) + } +} + +impl Service for SendError +where + T: AsyncRead + AsyncWrite, + E: ResponseError, +{ + type Request = Result)>; + type Response = R; + type Error = (E, Framed); + type Future = Either)>, SendErrorFut>; + + fn poll_ready(&mut self) -> Poll<(), Self::Error> { + Ok(Async::Ready(())) + } + + fn call(&mut self, req: Result)>) -> Self::Future { + match req { + Ok(r) => Either::A(ok(r)), + Err((e, framed)) => { + let res = e.error_response().set_body(format!("{}", e)); + let (res, _body) = res.replace_body(()); + Either::B(SendErrorFut { + framed: Some(framed), + res: Some((res, BodyLength::Empty).into()), + err: Some(e), + _t: PhantomData, + }) + } + } + } +} + +pub struct SendErrorFut { + res: Option, BodyLength)>>, + framed: Option>, + err: Option, + _t: PhantomData, +} + +impl Future for SendErrorFut +where + E: ResponseError, + T: AsyncRead + AsyncWrite, +{ + type Item = R; + type Error = (E, Framed); + + fn poll(&mut self) -> Poll { + if let Some(res) = self.res.take() { + if self.framed.as_mut().unwrap().force_send(res).is_err() { + return Err((self.err.take().unwrap(), self.framed.take().unwrap())); + } + } + match self.framed.as_mut().unwrap().poll_complete() { + Ok(Async::Ready(_)) => { + Err((self.err.take().unwrap(), self.framed.take().unwrap())) + } + Ok(Async::NotReady) => Ok(Async::NotReady), + Err(_) => Err((self.err.take().unwrap(), self.framed.take().unwrap())), + } + } +} + +pub struct SendResponse(PhantomData<(T, B)>); + +impl Default for SendResponse { + fn default() -> Self { + SendResponse(PhantomData) + } +} + +impl SendResponse +where + T: AsyncRead + AsyncWrite, + B: MessageBody, +{ + pub fn send( + framed: Framed, + res: Response, + ) -> impl Future, Error = Error> { + // extract body from response + let (res, body) = res.replace_body(()); + + // write response + SendResponseFut { + res: Some(Message::Item((res, body.length()))), + body: Some(body), + framed: Some(framed), + } + } +} + +impl NewService for SendResponse +where + T: AsyncRead + AsyncWrite, + B: MessageBody, +{ + type Request = (Response, Framed); + type Response = Framed; + type Error = Error; + type InitError = (); + type Service = SendResponse; + type Future = FutureResult; + + fn new_service(&self, _: &()) -> Self::Future { + ok(SendResponse(PhantomData)) + } +} + +impl Service for SendResponse +where + T: AsyncRead + AsyncWrite, + B: MessageBody, +{ + type Request = (Response, Framed); + type Response = Framed; + type Error = Error; + type Future = SendResponseFut; + + fn poll_ready(&mut self) -> Poll<(), Self::Error> { + Ok(Async::Ready(())) + } + + fn call(&mut self, (res, framed): (Response, Framed)) -> Self::Future { + let (res, body) = res.replace_body(()); + SendResponseFut { + res: Some(Message::Item((res, body.length()))), + body: Some(body), + framed: Some(framed), + } + } +} + +pub struct SendResponseFut { + res: Option, BodyLength)>>, + body: Option>, + framed: Option>, +} + +impl Future for SendResponseFut +where + T: AsyncRead + AsyncWrite, + B: MessageBody, +{ + type Item = Framed; + type Error = Error; + + fn poll(&mut self) -> Poll { + loop { + let mut body_ready = self.body.is_some(); + let framed = self.framed.as_mut().unwrap(); + + // send body + if self.res.is_none() && self.body.is_some() { + while body_ready && self.body.is_some() && !framed.is_write_buf_full() { + match self.body.as_mut().unwrap().poll_next()? { + Async::Ready(item) => { + // body is done + if item.is_none() { + let _ = self.body.take(); + } + framed.force_send(Message::Chunk(item))?; + } + Async::NotReady => body_ready = false, + } + } + } + + // flush write buffer + if !framed.is_write_buf_empty() { + match framed.poll_complete()? { + Async::Ready(_) => { + if body_ready { + continue; + } else { + return Ok(Async::NotReady); + } + } + Async::NotReady => return Ok(Async::NotReady), + } + } + + // send response + if let Some(res) = self.res.take() { + framed.force_send(res)?; + continue; + } + + if self.body.is_some() { + if body_ready { + continue; + } else { + return Ok(Async::NotReady); + } + } else { + break; + } + } + Ok(Async::Ready(self.framed.take().unwrap())) + } +} diff --git a/actix-http/src/service/service.rs b/actix-http/src/service/service.rs new file mode 100644 index 000000000..0bc1634d8 --- /dev/null +++ b/actix-http/src/service/service.rs @@ -0,0 +1,344 @@ +use std::fmt::Debug; +use std::marker::PhantomData; +use std::{fmt, io}; + +use actix_codec::{AsyncRead, AsyncWrite, Framed, FramedParts}; +use actix_server_config::{Io as ServerIo, Protocol, ServerConfig as SrvConfig}; +use actix_service::{IntoNewService, NewService, Service}; +use actix_utils::cloneable::CloneableService; +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use futures::{try_ready, Async, Future, IntoFuture, Poll}; +use h2::server::{self, Handshake}; +use log::error; + +use crate::body::MessageBody; +use crate::builder::HttpServiceBuilder; +use crate::config::{KeepAlive, ServiceConfig}; +use crate::error::DispatchError; +use crate::request::Request; +use crate::response::Response; +use crate::{h1, h2::Dispatcher}; + +/// `NewService` HTTP1.1/HTTP2 transport implementation +pub struct HttpService { + srv: S, + cfg: ServiceConfig, + _t: PhantomData<(T, P, B)>, +} + +impl HttpService +where + S: NewService, + S::Service: 'static, + S::Error: Debug + 'static, + S::Response: Into>, + B: MessageBody + 'static, +{ + /// Create builder for `HttpService` instance. + pub fn build() -> HttpServiceBuilder { + HttpServiceBuilder::new() + } +} + +impl HttpService +where + S: NewService, + S::Service: 'static, + S::Error: Debug + 'static, + S::Response: Into>, + B: MessageBody + 'static, +{ + /// Create new `HttpService` instance. + pub fn new>(service: F) -> Self { + let cfg = ServiceConfig::new(KeepAlive::Timeout(5), 5000, 0); + + HttpService { + cfg, + srv: service.into_new_service(), + _t: PhantomData, + } + } + + /// Create new `HttpService` instance with config. + pub(crate) fn with_config>( + cfg: ServiceConfig, + service: F, + ) -> Self { + HttpService { + cfg, + srv: service.into_new_service(), + _t: PhantomData, + } + } +} + +impl NewService for HttpService +where + T: AsyncRead + AsyncWrite + 'static, + S: NewService, + S::Service: 'static, + S::Error: Debug, + S::Response: Into>, + B: MessageBody + 'static, +{ + type Request = ServerIo; + type Response = (); + type Error = DispatchError; + type InitError = S::InitError; + type Service = HttpServiceHandler; + type Future = HttpServiceResponse; + + fn new_service(&self, cfg: &SrvConfig) -> Self::Future { + HttpServiceResponse { + fut: self.srv.new_service(cfg).into_future(), + cfg: Some(self.cfg.clone()), + _t: PhantomData, + } + } +} + +#[doc(hidden)] +pub struct HttpServiceResponse, B> { + fut: ::Future, + cfg: Option, + _t: PhantomData<(T, P, B)>, +} + +impl Future for HttpServiceResponse +where + T: AsyncRead + AsyncWrite, + S: NewService, + S::Service: 'static, + S::Response: Into>, + S::Error: Debug, + B: MessageBody + 'static, +{ + type Item = HttpServiceHandler; + type Error = S::InitError; + + fn poll(&mut self) -> Poll { + let service = try_ready!(self.fut.poll()); + Ok(Async::Ready(HttpServiceHandler::new( + self.cfg.take().unwrap(), + service, + ))) + } +} + +/// `Service` implementation for http transport +pub struct HttpServiceHandler { + srv: CloneableService, + cfg: ServiceConfig, + _t: PhantomData<(T, P, B)>, +} + +impl HttpServiceHandler +where + S: Service + 'static, + S::Error: Debug, + S::Response: Into>, + B: MessageBody + 'static, +{ + fn new(cfg: ServiceConfig, srv: S) -> HttpServiceHandler { + HttpServiceHandler { + cfg, + srv: CloneableService::new(srv), + _t: PhantomData, + } + } +} + +impl Service for HttpServiceHandler +where + T: AsyncRead + AsyncWrite + 'static, + S: Service + 'static, + S::Error: Debug, + S::Response: Into>, + B: MessageBody + 'static, +{ + type Request = ServerIo; + type Response = (); + type Error = DispatchError; + type Future = HttpServiceHandlerResponse; + + fn poll_ready(&mut self) -> Poll<(), Self::Error> { + self.srv.poll_ready().map_err(|e| { + error!("Service readiness error: {:?}", e); + DispatchError::Service + }) + } + + fn call(&mut self, req: Self::Request) -> Self::Future { + let (io, _, proto) = req.into_parts(); + match proto { + Protocol::Http2 => { + let io = Io { + inner: io, + unread: None, + }; + HttpServiceHandlerResponse { + state: State::Handshake(Some(( + server::handshake(io), + self.cfg.clone(), + self.srv.clone(), + ))), + } + } + Protocol::Http10 | Protocol::Http11 => HttpServiceHandlerResponse { + state: State::H1(h1::Dispatcher::new( + io, + self.cfg.clone(), + self.srv.clone(), + )), + }, + _ => HttpServiceHandlerResponse { + state: State::Unknown(Some(( + io, + BytesMut::with_capacity(14), + self.cfg.clone(), + self.srv.clone(), + ))), + }, + } + } +} + +enum State + 'static, B: MessageBody> +where + S::Error: fmt::Debug, + T: AsyncRead + AsyncWrite + 'static, +{ + H1(h1::Dispatcher), + H2(Dispatcher, S, B>), + Unknown(Option<(T, BytesMut, ServiceConfig, CloneableService)>), + Handshake(Option<(Handshake, Bytes>, ServiceConfig, CloneableService)>), +} + +pub struct HttpServiceHandlerResponse +where + T: AsyncRead + AsyncWrite + 'static, + S: Service + 'static, + S::Error: Debug, + S::Response: Into>, + B: MessageBody + 'static, +{ + state: State, +} + +const HTTP2_PREFACE: [u8; 14] = *b"PRI * HTTP/2.0"; + +impl Future for HttpServiceHandlerResponse +where + T: AsyncRead + AsyncWrite, + S: Service + 'static, + S::Error: Debug, + S::Response: Into>, + B: MessageBody, +{ + type Item = (); + type Error = DispatchError; + + fn poll(&mut self) -> Poll { + match self.state { + State::H1(ref mut disp) => disp.poll(), + State::H2(ref mut disp) => disp.poll(), + State::Unknown(ref mut data) => { + if let Some(ref mut item) = data { + loop { + unsafe { + let b = item.1.bytes_mut(); + let n = { try_ready!(item.0.poll_read(b)) }; + item.1.advance_mut(n); + if item.1.len() >= HTTP2_PREFACE.len() { + break; + } + } + } + } else { + panic!() + } + let (io, buf, cfg, srv) = data.take().unwrap(); + if buf[..14] == HTTP2_PREFACE[..] { + let io = Io { + inner: io, + unread: Some(buf), + }; + self.state = + State::Handshake(Some((server::handshake(io), cfg, srv))); + } else { + let framed = Framed::from_parts(FramedParts::with_read_buf( + io, + h1::Codec::new(cfg.clone()), + buf, + )); + self.state = + State::H1(h1::Dispatcher::with_timeout(framed, cfg, None, srv)) + } + self.poll() + } + State::Handshake(ref mut data) => { + let conn = if let Some(ref mut item) = data { + match item.0.poll() { + Ok(Async::Ready(conn)) => conn, + Ok(Async::NotReady) => return Ok(Async::NotReady), + Err(err) => { + trace!("H2 handshake error: {}", err); + return Err(err.into()); + } + } + } else { + panic!() + }; + let (_, cfg, srv) = data.take().unwrap(); + self.state = State::H2(Dispatcher::new(srv, conn, cfg, None)); + self.poll() + } + } + } +} + +/// Wrapper for `AsyncRead + AsyncWrite` types +struct Io { + unread: Option, + inner: T, +} + +impl io::Read for Io { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + if let Some(mut bytes) = self.unread.take() { + let size = std::cmp::min(buf.len(), bytes.len()); + buf[..size].copy_from_slice(&bytes[..size]); + if bytes.len() > size { + bytes.split_to(size); + self.unread = Some(bytes); + } + Ok(size) + } else { + self.inner.read(buf) + } + } +} + +impl io::Write for Io { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.inner.write(buf) + } + fn flush(&mut self) -> io::Result<()> { + self.inner.flush() + } +} + +impl AsyncRead for Io { + unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { + self.inner.prepare_uninitialized_buffer(buf) + } +} + +impl AsyncWrite for Io { + fn shutdown(&mut self) -> Poll<(), io::Error> { + self.inner.shutdown() + } + fn write_buf(&mut self, buf: &mut B) -> Poll { + self.inner.write_buf(buf) + } +} diff --git a/actix-http/src/test.rs b/actix-http/src/test.rs new file mode 100644 index 000000000..2d4b3d0f9 --- /dev/null +++ b/actix-http/src/test.rs @@ -0,0 +1,192 @@ +//! Test Various helpers for Actix applications to use during testing. +use std::str::FromStr; + +use bytes::Bytes; +#[cfg(feature = "cookies")] +use cookie::{Cookie, CookieJar}; +use http::header::HeaderName; +use http::{HeaderMap, HttpTryFrom, Method, Uri, Version}; + +use crate::header::{Header, IntoHeaderValue}; +use crate::payload::Payload; +use crate::Request; + +/// Test `Request` builder +/// +/// ```rust,ignore +/// # extern crate http; +/// # extern crate actix_web; +/// # use http::{header, StatusCode}; +/// # use actix_web::*; +/// use actix_web::test::TestRequest; +/// +/// fn index(req: &HttpRequest) -> Response { +/// if let Some(hdr) = req.headers().get(header::CONTENT_TYPE) { +/// Response::Ok().into() +/// } else { +/// Response::BadRequest().into() +/// } +/// } +/// +/// fn main() { +/// let resp = TestRequest::with_header("content-type", "text/plain") +/// .run(&index) +/// .unwrap(); +/// assert_eq!(resp.status(), StatusCode::OK); +/// +/// let resp = TestRequest::default().run(&index).unwrap(); +/// assert_eq!(resp.status(), StatusCode::BAD_REQUEST); +/// } +/// ``` +pub struct TestRequest(Option); + +struct Inner { + version: Version, + method: Method, + uri: Uri, + headers: HeaderMap, + #[cfg(feature = "cookies")] + cookies: CookieJar, + payload: Option, +} + +impl Default for TestRequest { + fn default() -> TestRequest { + TestRequest(Some(Inner { + method: Method::GET, + uri: Uri::from_str("/").unwrap(), + version: Version::HTTP_11, + headers: HeaderMap::new(), + #[cfg(feature = "cookies")] + cookies: CookieJar::new(), + payload: None, + })) + } +} + +impl TestRequest { + /// Create TestRequest and set request uri + pub fn with_uri(path: &str) -> TestRequest { + TestRequest::default().uri(path).take() + } + + /// Create TestRequest and set header + pub fn with_hdr(hdr: H) -> TestRequest { + TestRequest::default().set(hdr).take() + } + + /// Create TestRequest and set header + pub fn with_header(key: K, value: V) -> TestRequest + where + HeaderName: HttpTryFrom, + V: IntoHeaderValue, + { + TestRequest::default().header(key, value).take() + } + + /// Set HTTP version of this request + pub fn version(&mut self, ver: Version) -> &mut Self { + parts(&mut self.0).version = ver; + self + } + + /// Set HTTP method of this request + pub fn method(&mut self, meth: Method) -> &mut Self { + parts(&mut self.0).method = meth; + self + } + + /// Set HTTP Uri of this request + pub fn uri(&mut self, path: &str) -> &mut Self { + parts(&mut self.0).uri = Uri::from_str(path).unwrap(); + self + } + + /// Set a header + pub fn set(&mut self, hdr: H) -> &mut Self { + if let Ok(value) = hdr.try_into() { + parts(&mut self.0).headers.append(H::name(), value); + return self; + } + panic!("Can not set header"); + } + + /// Set a header + pub fn header(&mut self, key: K, value: V) -> &mut Self + where + HeaderName: HttpTryFrom, + V: IntoHeaderValue, + { + if let Ok(key) = HeaderName::try_from(key) { + if let Ok(value) = value.try_into() { + parts(&mut self.0).headers.append(key, value); + return self; + } + } + panic!("Can not create header"); + } + + /// Set cookie for this request + #[cfg(feature = "cookies")] + pub fn cookie<'a>(&mut self, cookie: Cookie<'a>) -> &mut Self { + parts(&mut self.0).cookies.add(cookie.into_owned()); + self + } + + /// Set request payload + pub fn set_payload>(&mut self, data: B) -> &mut Self { + let mut payload = crate::h1::Payload::empty(); + payload.unread_data(data.into()); + parts(&mut self.0).payload = Some(payload.into()); + self + } + + pub fn take(&mut self) -> TestRequest { + TestRequest(self.0.take()) + } + + /// Complete request creation and generate `Request` instance + pub fn finish(&mut self) -> Request { + let inner = self.0.take().expect("cannot reuse test request builder");; + + let mut req = if let Some(pl) = inner.payload { + Request::with_payload(pl) + } else { + Request::with_payload(crate::h1::Payload::empty().into()) + }; + + let head = req.head_mut(); + head.uri = inner.uri; + head.method = inner.method; + head.version = inner.version; + head.headers = inner.headers; + + #[cfg(feature = "cookies")] + { + use std::fmt::Write as FmtWrite; + + use http::header::{self, HeaderValue}; + use percent_encoding::{percent_encode, USERINFO_ENCODE_SET}; + + let mut cookie = String::new(); + for c in inner.cookies.delta() { + let name = percent_encode(c.name().as_bytes(), USERINFO_ENCODE_SET); + let value = percent_encode(c.value().as_bytes(), USERINFO_ENCODE_SET); + let _ = write!(&mut cookie, "; {}={}", name, value); + } + if !cookie.is_empty() { + head.headers.insert( + header::COOKIE, + HeaderValue::from_str(&cookie.as_str()[2..]).unwrap(), + ); + } + } + + req + } +} + +#[inline] +fn parts<'a>(parts: &'a mut Option) -> &'a mut Inner { + parts.as_mut().expect("cannot reuse test request builder") +} diff --git a/actix-http/src/ws/client/connect.rs b/actix-http/src/ws/client/connect.rs new file mode 100644 index 000000000..2760967e0 --- /dev/null +++ b/actix-http/src/ws/client/connect.rs @@ -0,0 +1,109 @@ +//! Http client request +use std::str; + +#[cfg(feature = "cookies")] +use cookie::Cookie; +use http::header::{HeaderName, HeaderValue}; +use http::{Error as HttpError, HttpTryFrom, Uri}; + +use super::ClientError; +use crate::header::IntoHeaderValue; +use crate::message::RequestHead; + +/// `WebSocket` connection +pub struct Connect { + pub(super) head: RequestHead, + pub(super) err: Option, + pub(super) http_err: Option, + pub(super) origin: Option, + pub(super) protocols: Option, + pub(super) max_size: usize, + pub(super) server_mode: bool, +} + +impl Connect { + /// Create new websocket connection + pub fn new>(uri: S) -> Connect { + let mut cl = Connect { + head: RequestHead::default(), + err: None, + http_err: None, + origin: None, + protocols: None, + max_size: 65_536, + server_mode: false, + }; + + match Uri::try_from(uri.as_ref()) { + Ok(uri) => cl.head.uri = uri, + Err(e) => cl.http_err = Some(e.into()), + } + + cl + } + + /// Set supported websocket protocols + pub fn protocols(mut self, protos: U) -> Self + where + U: IntoIterator + 'static, + V: AsRef, + { + let mut protos = protos + .into_iter() + .fold(String::new(), |acc, s| acc + s.as_ref() + ","); + protos.pop(); + self.protocols = Some(protos); + self + } + + // #[cfg(feature = "cookies")] + // /// Set cookie for handshake request + // pub fn cookie(mut self, cookie: Cookie) -> Self { + // self.request.cookie(cookie); + // self + // } + + /// Set request Origin + pub fn origin(mut self, origin: V) -> Self + where + HeaderValue: HttpTryFrom, + { + match HeaderValue::try_from(origin) { + Ok(value) => self.origin = Some(value), + Err(e) => self.http_err = Some(e.into()), + } + self + } + + /// Set max frame size + /// + /// By default max size is set to 64kb + pub fn max_frame_size(mut self, size: usize) -> Self { + self.max_size = size; + self + } + + /// Disable payload masking. By default ws client masks frame payload. + pub fn server_mode(mut self) -> Self { + self.server_mode = true; + self + } + + /// Set request header + pub fn header(mut self, key: K, value: V) -> Self + where + HeaderName: HttpTryFrom, + V: IntoHeaderValue, + { + match HeaderName::try_from(key) { + Ok(key) => match value.try_into() { + Ok(value) => { + self.head.headers.append(key, value); + } + Err(e) => self.http_err = Some(e.into()), + }, + Err(e) => self.http_err = Some(e.into()), + } + self + } +} diff --git a/actix-http/src/ws/client/error.rs b/actix-http/src/ws/client/error.rs new file mode 100644 index 000000000..ae1e39967 --- /dev/null +++ b/actix-http/src/ws/client/error.rs @@ -0,0 +1,53 @@ +//! Http client request +use std::io; + +use actix_connect::ConnectError; +use derive_more::{Display, From}; +use http::{header::HeaderValue, Error as HttpError, StatusCode}; + +use crate::error::ParseError; +use crate::ws::ProtocolError; + +/// Websocket client error +#[derive(Debug, Display, From)] +pub enum ClientError { + /// Invalid url + #[display(fmt = "Invalid url")] + InvalidUrl, + /// Invalid response status + #[display(fmt = "Invalid response status")] + InvalidResponseStatus(StatusCode), + /// Invalid upgrade header + #[display(fmt = "Invalid upgrade header")] + InvalidUpgradeHeader, + /// Invalid connection header + #[display(fmt = "Invalid connection header")] + InvalidConnectionHeader(HeaderValue), + /// Missing CONNECTION header + #[display(fmt = "Missing CONNECTION header")] + MissingConnectionHeader, + /// Missing SEC-WEBSOCKET-ACCEPT header + #[display(fmt = "Missing SEC-WEBSOCKET-ACCEPT header")] + MissingWebSocketAcceptHeader, + /// Invalid challenge response + #[display(fmt = "Invalid challenge response")] + InvalidChallengeResponse(String, HeaderValue), + /// Http parsing error + #[display(fmt = "Http parsing error")] + Http(HttpError), + /// Response parsing error + #[display(fmt = "Response parsing error: {}", _0)] + ParseError(ParseError), + /// Protocol error + #[display(fmt = "{}", _0)] + Protocol(ProtocolError), + /// Connect error + #[display(fmt = "Connector error: {:?}", _0)] + Connect(ConnectError), + /// IO Error + #[display(fmt = "{}", _0)] + Io(io::Error), + /// "Disconnected" + #[display(fmt = "Disconnected")] + Disconnected, +} diff --git a/actix-http/src/ws/client/mod.rs b/actix-http/src/ws/client/mod.rs new file mode 100644 index 000000000..a5c221967 --- /dev/null +++ b/actix-http/src/ws/client/mod.rs @@ -0,0 +1,48 @@ +mod connect; +mod error; +mod service; + +pub use self::connect::Connect; +pub use self::error::ClientError; +pub use self::service::Client; + +#[derive(PartialEq, Hash, Debug, Clone, Copy)] +pub(crate) enum Protocol { + Http, + Https, + Ws, + Wss, +} + +impl Protocol { + fn from(s: &str) -> Option { + match s { + "http" => Some(Protocol::Http), + "https" => Some(Protocol::Https), + "ws" => Some(Protocol::Ws), + "wss" => Some(Protocol::Wss), + _ => None, + } + } + + // fn is_http(self) -> bool { + // match self { + // Protocol::Https | Protocol::Http => true, + // _ => false, + // } + // } + + // fn is_secure(self) -> bool { + // match self { + // Protocol::Https | Protocol::Wss => true, + // _ => false, + // } + // } + + fn port(self) -> u16 { + match self { + Protocol::Http | Protocol::Ws => 80, + Protocol::Https | Protocol::Wss => 443, + } + } +} diff --git a/actix-http/src/ws/client/service.rs b/actix-http/src/ws/client/service.rs new file mode 100644 index 000000000..8a0840f90 --- /dev/null +++ b/actix-http/src/ws/client/service.rs @@ -0,0 +1,272 @@ +//! websockets client +use std::marker::PhantomData; + +use actix_codec::{AsyncRead, AsyncWrite, Framed}; +use actix_connect::{default_connector, Connect as TcpConnect, ConnectError}; +use actix_service::{apply_fn, Service}; +use base64; +use futures::future::{err, Either, FutureResult}; +use futures::{try_ready, Async, Future, Poll, Sink, Stream}; +use http::header::{self, HeaderValue}; +use http::{HttpTryFrom, StatusCode}; +use log::trace; +use rand; +use sha1::Sha1; + +use crate::body::BodyLength; +use crate::h1; +use crate::message::{ConnectionType, Head, ResponseHead}; +use crate::ws::Codec; + +use super::{ClientError, Connect, Protocol}; + +/// WebSocket's client +pub struct Client { + connector: T, +} + +impl Client<()> { + /// Create client with default connector. + pub fn default() -> Client< + impl Service< + Request = TcpConnect, + Response = impl AsyncRead + AsyncWrite, + Error = ConnectError, + > + Clone, + > { + Client::new(apply_fn(default_connector(), |msg: TcpConnect<_>, srv| { + srv.call(msg).map(|stream| stream.into_parts().0) + })) + } +} + +impl Client +where + T: Service, Error = ConnectError>, + T::Response: AsyncRead + AsyncWrite, +{ + /// Create new websocket's client factory + pub fn new(connector: T) -> Self { + Client { connector } + } +} + +impl Clone for Client +where + T: Service, Error = ConnectError> + Clone, + T::Response: AsyncRead + AsyncWrite, +{ + fn clone(&self) -> Self { + Client { + connector: self.connector.clone(), + } + } +} + +impl Service for Client +where + T: Service, Error = ConnectError>, + T::Response: AsyncRead + AsyncWrite + 'static, + T::Future: 'static, +{ + type Request = Connect; + type Response = Framed; + type Error = ClientError; + type Future = Either< + FutureResult, + ClientResponseFut, + >; + + fn poll_ready(&mut self) -> Poll<(), Self::Error> { + self.connector.poll_ready().map_err(ClientError::from) + } + + fn call(&mut self, mut req: Connect) -> Self::Future { + if let Some(e) = req.err.take() { + Either::A(err(e)) + } else if let Some(e) = req.http_err.take() { + Either::A(err(e.into())) + } else { + // origin + if let Some(origin) = req.origin.take() { + req.head.headers.insert(header::ORIGIN, origin); + } + + req.head.set_connection_type(ConnectionType::Upgrade); + req.head + .headers + .insert(header::UPGRADE, HeaderValue::from_static("websocket")); + req.head.headers.insert( + header::SEC_WEBSOCKET_VERSION, + HeaderValue::from_static("13"), + ); + + if let Some(protocols) = req.protocols.take() { + req.head.headers.insert( + header::SEC_WEBSOCKET_PROTOCOL, + HeaderValue::try_from(protocols.as_str()).unwrap(), + ); + } + if let Some(e) = req.http_err { + return Either::A(err(e.into())); + }; + + let mut request = req.head; + if request.uri.host().is_none() { + return Either::A(err(ClientError::InvalidUrl)); + } + + // supported protocols + let proto = if let Some(scheme) = request.uri.scheme_part() { + match Protocol::from(scheme.as_str()) { + Some(proto) => proto, + None => return Either::A(err(ClientError::InvalidUrl)), + } + } else { + return Either::A(err(ClientError::InvalidUrl)); + }; + + // Generate a random key for the `Sec-WebSocket-Key` header. + // a base64-encoded (see Section 4 of [RFC4648]) value that, + // when decoded, is 16 bytes in length (RFC 6455) + let sec_key: [u8; 16] = rand::random(); + let key = base64::encode(&sec_key); + + request.headers.insert( + header::SEC_WEBSOCKET_KEY, + HeaderValue::try_from(key.as_str()).unwrap(), + ); + + // prep connection + let connect = TcpConnect::new(request.uri.host().unwrap().to_string()) + .set_port(request.uri.port_u16().unwrap_or_else(|| proto.port())); + + let fut = Box::new( + self.connector + .call(connect) + .map_err(ClientError::from) + .and_then(move |io| { + // h1 protocol + let framed = Framed::new(io, h1::ClientCodec::default()); + framed + .send((request, BodyLength::None).into()) + .map_err(ClientError::from) + .and_then(|framed| { + framed + .into_future() + .map_err(|(e, _)| ClientError::from(e)) + }) + }), + ); + + // start handshake + Either::B(ClientResponseFut { + key, + fut, + max_size: req.max_size, + server_mode: req.server_mode, + _t: PhantomData, + }) + } + } +} + +/// Future that implementes client websocket handshake process. +/// +/// It resolves to a `Framed` instance. +pub struct ClientResponseFut +where + T: AsyncRead + AsyncWrite, +{ + fut: Box< + Future< + Item = (Option, Framed), + Error = ClientError, + >, + >, + key: String, + max_size: usize, + server_mode: bool, + _t: PhantomData, +} + +impl Future for ClientResponseFut +where + T: AsyncRead + AsyncWrite, +{ + type Item = Framed; + type Error = ClientError; + + fn poll(&mut self) -> Poll { + let (item, framed) = try_ready!(self.fut.poll()); + + let res = match item { + Some(res) => res, + None => return Err(ClientError::Disconnected), + }; + + // verify response + if res.status != StatusCode::SWITCHING_PROTOCOLS { + return Err(ClientError::InvalidResponseStatus(res.status)); + } + // Check for "UPGRADE" to websocket header + let has_hdr = if let Some(hdr) = res.headers.get(header::UPGRADE) { + if let Ok(s) = hdr.to_str() { + s.to_lowercase().contains("websocket") + } else { + false + } + } else { + false + }; + if !has_hdr { + trace!("Invalid upgrade header"); + return Err(ClientError::InvalidUpgradeHeader); + } + // Check for "CONNECTION" header + if let Some(conn) = res.headers.get(header::CONNECTION) { + if let Ok(s) = conn.to_str() { + if !s.to_lowercase().contains("upgrade") { + trace!("Invalid connection header: {}", s); + return Err(ClientError::InvalidConnectionHeader(conn.clone())); + } + } else { + trace!("Invalid connection header: {:?}", conn); + return Err(ClientError::InvalidConnectionHeader(conn.clone())); + } + } else { + trace!("Missing connection header"); + return Err(ClientError::MissingConnectionHeader); + } + + if let Some(key) = res.headers.get(header::SEC_WEBSOCKET_ACCEPT) { + // field is constructed by concatenating /key/ + // with the string "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" (RFC 6455) + const WS_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + let mut sha1 = Sha1::new(); + sha1.update(self.key.as_ref()); + sha1.update(WS_GUID); + let encoded = base64::encode(&sha1.digest().bytes()); + if key.as_bytes() != encoded.as_bytes() { + trace!( + "Invalid challenge response: expected: {} received: {:?}", + encoded, + key + ); + return Err(ClientError::InvalidChallengeResponse(encoded, key.clone())); + } + } else { + trace!("Missing SEC-WEBSOCKET-ACCEPT header"); + return Err(ClientError::MissingWebSocketAcceptHeader); + }; + + // websockets codec + let codec = if self.server_mode { + Codec::new().max_size(self.max_size) + } else { + Codec::new().max_size(self.max_size).client_mode() + }; + + Ok(Async::Ready(framed.into_framed(codec))) + } +} diff --git a/actix-http/src/ws/codec.rs b/actix-http/src/ws/codec.rs new file mode 100644 index 000000000..286d15f8c --- /dev/null +++ b/actix-http/src/ws/codec.rs @@ -0,0 +1,147 @@ +use actix_codec::{Decoder, Encoder}; +use bytes::{Bytes, BytesMut}; + +use super::frame::Parser; +use super::proto::{CloseReason, OpCode}; +use super::ProtocolError; + +/// `WebSocket` Message +#[derive(Debug, PartialEq)] +pub enum Message { + /// Text message + Text(String), + /// Binary message + Binary(Bytes), + /// Ping message + Ping(String), + /// Pong message + Pong(String), + /// Close message with optional reason + Close(Option), +} + +/// `WebSocket` frame +#[derive(Debug, PartialEq)] +pub enum Frame { + /// Text frame, codec does not verify utf8 encoding + Text(Option), + /// Binary frame + Binary(Option), + /// Ping message + Ping(String), + /// Pong message + Pong(String), + /// Close message with optional reason + Close(Option), +} + +#[derive(Debug)] +/// WebSockets protocol codec +pub struct Codec { + max_size: usize, + server: bool, +} + +impl Codec { + /// Create new websocket frames decoder + pub fn new() -> Codec { + Codec { + max_size: 65_536, + server: true, + } + } + + /// Set max frame size + /// + /// By default max size is set to 64kb + pub fn max_size(mut self, size: usize) -> Self { + self.max_size = size; + self + } + + /// Set decoder to client mode. + /// + /// By default decoder works in server mode. + pub fn client_mode(mut self) -> Self { + self.server = false; + self + } +} + +impl Encoder for Codec { + type Item = Message; + type Error = ProtocolError; + + fn encode(&mut self, item: Message, dst: &mut BytesMut) -> Result<(), Self::Error> { + match item { + Message::Text(txt) => { + Parser::write_message(dst, txt, OpCode::Text, true, !self.server) + } + Message::Binary(bin) => { + Parser::write_message(dst, bin, OpCode::Binary, true, !self.server) + } + Message::Ping(txt) => { + Parser::write_message(dst, txt, OpCode::Ping, true, !self.server) + } + Message::Pong(txt) => { + Parser::write_message(dst, txt, OpCode::Pong, true, !self.server) + } + Message::Close(reason) => Parser::write_close(dst, reason, !self.server), + } + Ok(()) + } +} + +impl Decoder for Codec { + type Item = Frame; + type Error = ProtocolError; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + match Parser::parse(src, self.server, self.max_size) { + Ok(Some((finished, opcode, payload))) => { + // continuation is not supported + if !finished { + return Err(ProtocolError::NoContinuation); + } + + match opcode { + OpCode::Continue => Err(ProtocolError::NoContinuation), + OpCode::Bad => Err(ProtocolError::BadOpCode), + OpCode::Close => { + if let Some(ref pl) = payload { + let close_reason = Parser::parse_close_payload(pl); + Ok(Some(Frame::Close(close_reason))) + } else { + Ok(Some(Frame::Close(None))) + } + } + OpCode::Ping => { + if let Some(ref pl) = payload { + Ok(Some(Frame::Ping(String::from_utf8_lossy(pl).into()))) + } else { + Ok(Some(Frame::Ping(String::new()))) + } + } + OpCode::Pong => { + if let Some(ref pl) = payload { + Ok(Some(Frame::Pong(String::from_utf8_lossy(pl).into()))) + } else { + Ok(Some(Frame::Pong(String::new()))) + } + } + OpCode::Binary => Ok(Some(Frame::Binary(payload))), + OpCode::Text => { + Ok(Some(Frame::Text(payload))) + //let tmp = Vec::from(payload.as_ref()); + //match String::from_utf8(tmp) { + // Ok(s) => Ok(Some(Message::Text(s))), + // Err(_) => Err(ProtocolError::BadEncoding), + //} + } + } + } + Ok(None) => Ok(None), + Err(e) => Err(e), + } + } +} diff --git a/actix-http/src/ws/frame.rs b/actix-http/src/ws/frame.rs new file mode 100644 index 000000000..d4c15627f --- /dev/null +++ b/actix-http/src/ws/frame.rs @@ -0,0 +1,383 @@ +use byteorder::{ByteOrder, LittleEndian, NetworkEndian}; +use bytes::{BufMut, Bytes, BytesMut}; +use log::debug; +use rand; + +use crate::ws::mask::apply_mask; +use crate::ws::proto::{CloseCode, CloseReason, OpCode}; +use crate::ws::ProtocolError; + +/// A struct representing a `WebSocket` frame. +#[derive(Debug)] +pub struct Parser; + +impl Parser { + fn parse_metadata( + src: &[u8], + server: bool, + max_size: usize, + ) -> Result)>, ProtocolError> { + let chunk_len = src.len(); + + let mut idx = 2; + if chunk_len < 2 { + return Ok(None); + } + + let first = src[0]; + let second = src[1]; + let finished = first & 0x80 != 0; + + // check masking + let masked = second & 0x80 != 0; + if !masked && server { + return Err(ProtocolError::UnmaskedFrame); + } else if masked && !server { + return Err(ProtocolError::MaskedFrame); + } + + // Op code + let opcode = OpCode::from(first & 0x0F); + + if let OpCode::Bad = opcode { + return Err(ProtocolError::InvalidOpcode(first & 0x0F)); + } + + let len = second & 0x7F; + let length = if len == 126 { + if chunk_len < 4 { + return Ok(None); + } + let len = NetworkEndian::read_uint(&src[idx..], 2) as usize; + idx += 2; + len + } else if len == 127 { + if chunk_len < 10 { + return Ok(None); + } + let len = NetworkEndian::read_uint(&src[idx..], 8); + if len > max_size as u64 { + return Err(ProtocolError::Overflow); + } + idx += 8; + len as usize + } else { + len as usize + }; + + // check for max allowed size + if length > max_size { + return Err(ProtocolError::Overflow); + } + + let mask = if server { + if chunk_len < idx + 4 { + return Ok(None); + } + + let mask: &[u8] = &src[idx..idx + 4]; + let mask_u32 = LittleEndian::read_u32(mask); + idx += 4; + Some(mask_u32) + } else { + None + }; + + Ok(Some((idx, finished, opcode, length, mask))) + } + + /// Parse the input stream into a frame. + pub fn parse( + src: &mut BytesMut, + server: bool, + max_size: usize, + ) -> Result)>, ProtocolError> { + // try to parse ws frame metadata + let (idx, finished, opcode, length, mask) = + match Parser::parse_metadata(src, server, max_size)? { + None => return Ok(None), + Some(res) => res, + }; + + // not enough data + if src.len() < idx + length { + return Ok(None); + } + + // remove prefix + src.split_to(idx); + + // no need for body + if length == 0 { + return Ok(Some((finished, opcode, None))); + } + + let mut data = src.split_to(length); + + // control frames must have length <= 125 + match opcode { + OpCode::Ping | OpCode::Pong if length > 125 => { + return Err(ProtocolError::InvalidLength(length)); + } + OpCode::Close if length > 125 => { + debug!("Received close frame with payload length exceeding 125. Morphing to protocol close frame."); + return Ok(Some((true, OpCode::Close, None))); + } + _ => (), + } + + // unmask + if let Some(mask) = mask { + apply_mask(&mut data, mask); + } + + Ok(Some((finished, opcode, Some(data)))) + } + + /// Parse the payload of a close frame. + pub fn parse_close_payload(payload: &[u8]) -> Option { + if payload.len() >= 2 { + let raw_code = NetworkEndian::read_u16(payload); + let code = CloseCode::from(raw_code); + let description = if payload.len() > 2 { + Some(String::from_utf8_lossy(&payload[2..]).into()) + } else { + None + }; + Some(CloseReason { code, description }) + } else { + None + } + } + + /// Generate binary representation + pub fn write_message>( + dst: &mut BytesMut, + pl: B, + op: OpCode, + fin: bool, + mask: bool, + ) { + let payload = pl.into(); + let one: u8 = if fin { + 0x80 | Into::::into(op) + } else { + op.into() + }; + let payload_len = payload.len(); + let (two, p_len) = if mask { + (0x80, payload_len + 4) + } else { + (0, payload_len) + }; + + if payload_len < 126 { + dst.put_slice(&[one, two | payload_len as u8]); + } else if payload_len <= 65_535 { + dst.reserve(p_len + 4); + dst.put_slice(&[one, two | 126]); + dst.put_u16_be(payload_len as u16); + } else { + dst.reserve(p_len + 10); + dst.put_slice(&[one, two | 127]); + dst.put_u64_be(payload_len as u64); + }; + + if mask { + let mask = rand::random::(); + dst.put_u32_le(mask); + dst.extend_from_slice(payload.as_ref()); + let pos = dst.len() - payload_len; + apply_mask(&mut dst[pos..], mask); + } else { + dst.put_slice(payload.as_ref()); + } + } + + /// Create a new Close control frame. + #[inline] + pub fn write_close(dst: &mut BytesMut, reason: Option, mask: bool) { + let payload = match reason { + None => Vec::new(), + Some(reason) => { + let mut code_bytes = [0; 2]; + NetworkEndian::write_u16(&mut code_bytes, reason.code.into()); + + let mut payload = Vec::from(&code_bytes[..]); + if let Some(description) = reason.description { + payload.extend(description.as_bytes()); + } + payload + } + }; + + Parser::write_message(dst, payload, OpCode::Close, true, mask) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use bytes::Bytes; + + struct F { + finished: bool, + opcode: OpCode, + payload: Bytes, + } + + fn is_none( + frm: &Result)>, ProtocolError>, + ) -> bool { + match *frm { + Ok(None) => true, + _ => false, + } + } + + fn extract( + frm: Result)>, ProtocolError>, + ) -> F { + match frm { + Ok(Some((finished, opcode, payload))) => F { + finished, + opcode, + payload: payload + .map(|b| b.freeze()) + .unwrap_or_else(|| Bytes::from("")), + }, + _ => unreachable!("error"), + } + } + + #[test] + fn test_parse() { + let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0001u8][..]); + assert!(is_none(&Parser::parse(&mut buf, false, 1024))); + + let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0001u8][..]); + buf.extend(b"1"); + + let frame = extract(Parser::parse(&mut buf, false, 1024)); + assert!(!frame.finished); + assert_eq!(frame.opcode, OpCode::Text); + assert_eq!(frame.payload.as_ref(), &b"1"[..]); + } + + #[test] + fn test_parse_length0() { + let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0000u8][..]); + let frame = extract(Parser::parse(&mut buf, false, 1024)); + assert!(!frame.finished); + assert_eq!(frame.opcode, OpCode::Text); + assert!(frame.payload.is_empty()); + } + + #[test] + fn test_parse_length2() { + let mut buf = BytesMut::from(&[0b0000_0001u8, 126u8][..]); + assert!(is_none(&Parser::parse(&mut buf, false, 1024))); + + let mut buf = BytesMut::from(&[0b0000_0001u8, 126u8][..]); + buf.extend(&[0u8, 4u8][..]); + buf.extend(b"1234"); + + let frame = extract(Parser::parse(&mut buf, false, 1024)); + assert!(!frame.finished); + assert_eq!(frame.opcode, OpCode::Text); + assert_eq!(frame.payload.as_ref(), &b"1234"[..]); + } + + #[test] + fn test_parse_length4() { + let mut buf = BytesMut::from(&[0b0000_0001u8, 127u8][..]); + assert!(is_none(&Parser::parse(&mut buf, false, 1024))); + + let mut buf = BytesMut::from(&[0b0000_0001u8, 127u8][..]); + buf.extend(&[0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 4u8][..]); + buf.extend(b"1234"); + + let frame = extract(Parser::parse(&mut buf, false, 1024)); + assert!(!frame.finished); + assert_eq!(frame.opcode, OpCode::Text); + assert_eq!(frame.payload.as_ref(), &b"1234"[..]); + } + + #[test] + fn test_parse_frame_mask() { + let mut buf = BytesMut::from(&[0b0000_0001u8, 0b1000_0001u8][..]); + buf.extend(b"0001"); + buf.extend(b"1"); + + assert!(Parser::parse(&mut buf, false, 1024).is_err()); + + let frame = extract(Parser::parse(&mut buf, true, 1024)); + assert!(!frame.finished); + assert_eq!(frame.opcode, OpCode::Text); + assert_eq!(frame.payload, Bytes::from(vec![1u8])); + } + + #[test] + fn test_parse_frame_no_mask() { + let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0001u8][..]); + buf.extend(&[1u8]); + + assert!(Parser::parse(&mut buf, true, 1024).is_err()); + + let frame = extract(Parser::parse(&mut buf, false, 1024)); + assert!(!frame.finished); + assert_eq!(frame.opcode, OpCode::Text); + assert_eq!(frame.payload, Bytes::from(vec![1u8])); + } + + #[test] + fn test_parse_frame_max_size() { + let mut buf = BytesMut::from(&[0b0000_0001u8, 0b0000_0010u8][..]); + buf.extend(&[1u8, 1u8]); + + assert!(Parser::parse(&mut buf, true, 1).is_err()); + + if let Err(ProtocolError::Overflow) = Parser::parse(&mut buf, false, 0) { + } else { + unreachable!("error"); + } + } + + #[test] + fn test_ping_frame() { + let mut buf = BytesMut::new(); + Parser::write_message(&mut buf, Vec::from("data"), OpCode::Ping, true, false); + + let mut v = vec![137u8, 4u8]; + v.extend(b"data"); + assert_eq!(&buf[..], &v[..]); + } + + #[test] + fn test_pong_frame() { + let mut buf = BytesMut::new(); + Parser::write_message(&mut buf, Vec::from("data"), OpCode::Pong, true, false); + + let mut v = vec![138u8, 4u8]; + v.extend(b"data"); + assert_eq!(&buf[..], &v[..]); + } + + #[test] + fn test_close_frame() { + let mut buf = BytesMut::new(); + let reason = (CloseCode::Normal, "data"); + Parser::write_close(&mut buf, Some(reason.into()), false); + + let mut v = vec![136u8, 6u8, 3u8, 232u8]; + v.extend(b"data"); + assert_eq!(&buf[..], &v[..]); + } + + #[test] + fn test_empty_close_frame() { + let mut buf = BytesMut::new(); + Parser::write_close(&mut buf, None, false); + assert_eq!(&buf[..], &vec![0x88, 0x00][..]); + } +} diff --git a/actix-http/src/ws/mask.rs b/actix-http/src/ws/mask.rs new file mode 100644 index 000000000..157375417 --- /dev/null +++ b/actix-http/src/ws/mask.rs @@ -0,0 +1,149 @@ +//! This is code from [Tungstenite project](https://github.com/snapview/tungstenite-rs) +#![allow(clippy::cast_ptr_alignment)] +use std::ptr::copy_nonoverlapping; +use std::slice; + +// Holds a slice guaranteed to be shorter than 8 bytes +struct ShortSlice<'a>(&'a mut [u8]); + +impl<'a> ShortSlice<'a> { + unsafe fn new(slice: &'a mut [u8]) -> Self { + // Sanity check for debug builds + debug_assert!(slice.len() < 8); + ShortSlice(slice) + } + fn len(&self) -> usize { + self.0.len() + } +} + +/// Faster version of `apply_mask()` which operates on 8-byte blocks. +#[inline] +#[allow(clippy::cast_lossless)] +pub(crate) fn apply_mask(buf: &mut [u8], mask_u32: u32) { + // Extend the mask to 64 bits + let mut mask_u64 = ((mask_u32 as u64) << 32) | (mask_u32 as u64); + // Split the buffer into three segments + let (head, mid, tail) = align_buf(buf); + + // Initial unaligned segment + let head_len = head.len(); + if head_len > 0 { + xor_short(head, mask_u64); + if cfg!(target_endian = "big") { + mask_u64 = mask_u64.rotate_left(8 * head_len as u32); + } else { + mask_u64 = mask_u64.rotate_right(8 * head_len as u32); + } + } + // Aligned segment + for v in mid { + *v ^= mask_u64; + } + // Final unaligned segment + if tail.len() > 0 { + xor_short(tail, mask_u64); + } +} + +#[inline] +// TODO: copy_nonoverlapping here compiles to call memcpy. While it is not so +// inefficient, it could be done better. The compiler does not understand that +// a `ShortSlice` must be smaller than a u64. +#[allow(clippy::needless_pass_by_value)] +fn xor_short(buf: ShortSlice, mask: u64) { + // Unsafe: we know that a `ShortSlice` fits in a u64 + unsafe { + let (ptr, len) = (buf.0.as_mut_ptr(), buf.0.len()); + let mut b: u64 = 0; + #[allow(trivial_casts)] + copy_nonoverlapping(ptr, &mut b as *mut _ as *mut u8, len); + b ^= mask; + #[allow(trivial_casts)] + copy_nonoverlapping(&b as *const _ as *const u8, ptr, len); + } +} + +#[inline] +// Unsafe: caller must ensure the buffer has the correct size and alignment +unsafe fn cast_slice(buf: &mut [u8]) -> &mut [u64] { + // Assert correct size and alignment in debug builds + debug_assert!(buf.len().trailing_zeros() >= 3); + debug_assert!((buf.as_ptr() as usize).trailing_zeros() >= 3); + + slice::from_raw_parts_mut(buf.as_mut_ptr() as *mut u64, buf.len() >> 3) +} + +#[inline] +// Splits a slice into three parts: an unaligned short head and tail, plus an aligned +// u64 mid section. +fn align_buf(buf: &mut [u8]) -> (ShortSlice, &mut [u64], ShortSlice) { + let start_ptr = buf.as_ptr() as usize; + let end_ptr = start_ptr + buf.len(); + + // Round *up* to next aligned boundary for start + let start_aligned = (start_ptr + 7) & !0x7; + // Round *down* to last aligned boundary for end + let end_aligned = end_ptr & !0x7; + + if end_aligned >= start_aligned { + // We have our three segments (head, mid, tail) + let (tmp, tail) = buf.split_at_mut(end_aligned - start_ptr); + let (head, mid) = tmp.split_at_mut(start_aligned - start_ptr); + + // Unsafe: we know the middle section is correctly aligned, and the outer + // sections are smaller than 8 bytes + unsafe { (ShortSlice::new(head), cast_slice(mid), ShortSlice(tail)) } + } else { + // We didn't cross even one aligned boundary! + + // Unsafe: The outer sections are smaller than 8 bytes + unsafe { (ShortSlice::new(buf), &mut [], ShortSlice::new(&mut [])) } + } +} + +#[cfg(test)] +mod tests { + use super::apply_mask; + use byteorder::{ByteOrder, LittleEndian}; + + /// A safe unoptimized mask application. + fn apply_mask_fallback(buf: &mut [u8], mask: &[u8; 4]) { + for (i, byte) in buf.iter_mut().enumerate() { + *byte ^= mask[i & 3]; + } + } + + #[test] + fn test_apply_mask() { + let mask = [0x6d, 0xb6, 0xb2, 0x80]; + let mask_u32: u32 = LittleEndian::read_u32(&mask); + + let unmasked = vec![ + 0xf3, 0x00, 0x01, 0x02, 0x03, 0x80, 0x81, 0x82, 0xff, 0xfe, 0x00, 0x17, + 0x74, 0xf9, 0x12, 0x03, + ]; + + // Check masking with proper alignment. + { + let mut masked = unmasked.clone(); + apply_mask_fallback(&mut masked, &mask); + + let mut masked_fast = unmasked.clone(); + apply_mask(&mut masked_fast, mask_u32); + + assert_eq!(masked, masked_fast); + } + + // Check masking without alignment. + { + let mut masked = unmasked.clone(); + apply_mask_fallback(&mut masked[1..], &mask); + + let mut masked_fast = unmasked.clone(); + apply_mask(&mut masked_fast[1..], mask_u32); + + assert_eq!(masked, masked_fast); + } + } +} diff --git a/actix-http/src/ws/mod.rs b/actix-http/src/ws/mod.rs new file mode 100644 index 000000000..88fabde94 --- /dev/null +++ b/actix-http/src/ws/mod.rs @@ -0,0 +1,320 @@ +//! WebSocket protocol support. +//! +//! To setup a `WebSocket`, first do web socket handshake then on success +//! convert `Payload` into a `WsStream` stream and then use `WsWriter` to +//! communicate with the peer. +use std::io; + +use derive_more::{Display, From}; +use http::{header, Method, StatusCode}; + +use crate::error::ResponseError; +use crate::httpmessage::HttpMessage; +use crate::request::Request; +use crate::response::{Response, ResponseBuilder}; + +mod client; +mod codec; +mod frame; +mod mask; +mod proto; +mod service; +mod transport; + +pub use self::client::{Client, ClientError, Connect}; +pub use self::codec::{Codec, Frame, Message}; +pub use self::frame::Parser; +pub use self::proto::{hash_key, CloseCode, CloseReason, OpCode}; +pub use self::service::VerifyWebSockets; +pub use self::transport::Transport; + +/// Websocket protocol errors +#[derive(Debug, Display, From)] +pub enum ProtocolError { + /// Received an unmasked frame from client + #[display(fmt = "Received an unmasked frame from client")] + UnmaskedFrame, + /// Received a masked frame from server + #[display(fmt = "Received a masked frame from server")] + MaskedFrame, + /// Encountered invalid opcode + #[display(fmt = "Invalid opcode: {}", _0)] + InvalidOpcode(u8), + /// Invalid control frame length + #[display(fmt = "Invalid control frame length: {}", _0)] + InvalidLength(usize), + /// Bad web socket op code + #[display(fmt = "Bad web socket op code")] + BadOpCode, + /// A payload reached size limit. + #[display(fmt = "A payload reached size limit.")] + Overflow, + /// Continuation is not supported + #[display(fmt = "Continuation is not supported.")] + NoContinuation, + /// Bad utf-8 encoding + #[display(fmt = "Bad utf-8 encoding.")] + BadEncoding, + /// Io error + #[display(fmt = "io error: {}", _0)] + Io(io::Error), +} + +impl ResponseError for ProtocolError {} + +/// Websocket handshake errors +#[derive(PartialEq, Debug, Display)] +pub enum HandshakeError { + /// Only get method is allowed + #[display(fmt = "Method not allowed")] + GetMethodRequired, + /// Upgrade header if not set to websocket + #[display(fmt = "Websocket upgrade is expected")] + NoWebsocketUpgrade, + /// Connection header is not set to upgrade + #[display(fmt = "Connection upgrade is expected")] + NoConnectionUpgrade, + /// Websocket version header is not set + #[display(fmt = "Websocket version header is required")] + NoVersionHeader, + /// Unsupported websocket version + #[display(fmt = "Unsupported version")] + UnsupportedVersion, + /// Websocket key is not set or wrong + #[display(fmt = "Unknown websocket key")] + BadWebsocketKey, +} + +impl ResponseError for HandshakeError { + fn error_response(&self) -> Response { + match *self { + HandshakeError::GetMethodRequired => Response::MethodNotAllowed() + .header(header::ALLOW, "GET") + .finish(), + HandshakeError::NoWebsocketUpgrade => Response::BadRequest() + .reason("No WebSocket UPGRADE header found") + .finish(), + HandshakeError::NoConnectionUpgrade => Response::BadRequest() + .reason("No CONNECTION upgrade") + .finish(), + HandshakeError::NoVersionHeader => Response::BadRequest() + .reason("Websocket version header is required") + .finish(), + HandshakeError::UnsupportedVersion => Response::BadRequest() + .reason("Unsupported version") + .finish(), + HandshakeError::BadWebsocketKey => { + Response::BadRequest().reason("Handshake error").finish() + } + } + } +} + +/// Verify `WebSocket` handshake request and create handshake reponse. +// /// `protocols` is a sequence of known protocols. On successful handshake, +// /// the returned response headers contain the first protocol in this list +// /// which the server also knows. +pub fn handshake(req: &Request) -> Result { + verify_handshake(req)?; + Ok(handshake_response(req)) +} + +/// Verify `WebSocket` handshake request. +// /// `protocols` is a sequence of known protocols. On successful handshake, +// /// the returned response headers contain the first protocol in this list +// /// which the server also knows. +pub fn verify_handshake(req: &Request) -> Result<(), HandshakeError> { + // WebSocket accepts only GET + if *req.method() != Method::GET { + return Err(HandshakeError::GetMethodRequired); + } + + // Check for "UPGRADE" to websocket header + let has_hdr = if let Some(hdr) = req.headers().get(header::UPGRADE) { + if let Ok(s) = hdr.to_str() { + s.to_ascii_lowercase().contains("websocket") + } else { + false + } + } else { + false + }; + if !has_hdr { + return Err(HandshakeError::NoWebsocketUpgrade); + } + + // Upgrade connection + if !req.upgrade() { + return Err(HandshakeError::NoConnectionUpgrade); + } + + // check supported version + if !req.headers().contains_key(header::SEC_WEBSOCKET_VERSION) { + return Err(HandshakeError::NoVersionHeader); + } + let supported_ver = { + if let Some(hdr) = req.headers().get(header::SEC_WEBSOCKET_VERSION) { + hdr == "13" || hdr == "8" || hdr == "7" + } else { + false + } + }; + if !supported_ver { + return Err(HandshakeError::UnsupportedVersion); + } + + // check client handshake for validity + if !req.headers().contains_key(header::SEC_WEBSOCKET_KEY) { + return Err(HandshakeError::BadWebsocketKey); + } + Ok(()) +} + +/// Create websocket's handshake response +/// +/// This function returns handshake `Response`, ready to send to peer. +pub fn handshake_response(req: &Request) -> ResponseBuilder { + let key = { + let key = req.headers().get(header::SEC_WEBSOCKET_KEY).unwrap(); + proto::hash_key(key.as_ref()) + }; + + Response::build(StatusCode::SWITCHING_PROTOCOLS) + .upgrade("websocket") + .header(header::TRANSFER_ENCODING, "chunked") + .header(header::SEC_WEBSOCKET_ACCEPT, key.as_str()) + .take() +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::test::TestRequest; + use http::{header, Method}; + + #[test] + fn test_handshake() { + let req = TestRequest::default().method(Method::POST).finish(); + assert_eq!( + HandshakeError::GetMethodRequired, + verify_handshake(&req).err().unwrap() + ); + + let req = TestRequest::default().finish(); + assert_eq!( + HandshakeError::NoWebsocketUpgrade, + verify_handshake(&req).err().unwrap() + ); + + let req = TestRequest::default() + .header(header::UPGRADE, header::HeaderValue::from_static("test")) + .finish(); + assert_eq!( + HandshakeError::NoWebsocketUpgrade, + verify_handshake(&req).err().unwrap() + ); + + let req = TestRequest::default() + .header( + header::UPGRADE, + header::HeaderValue::from_static("websocket"), + ) + .finish(); + assert_eq!( + HandshakeError::NoConnectionUpgrade, + verify_handshake(&req).err().unwrap() + ); + + let req = TestRequest::default() + .header( + header::UPGRADE, + header::HeaderValue::from_static("websocket"), + ) + .header( + header::CONNECTION, + header::HeaderValue::from_static("upgrade"), + ) + .finish(); + assert_eq!( + HandshakeError::NoVersionHeader, + verify_handshake(&req).err().unwrap() + ); + + let req = TestRequest::default() + .header( + header::UPGRADE, + header::HeaderValue::from_static("websocket"), + ) + .header( + header::CONNECTION, + header::HeaderValue::from_static("upgrade"), + ) + .header( + header::SEC_WEBSOCKET_VERSION, + header::HeaderValue::from_static("5"), + ) + .finish(); + assert_eq!( + HandshakeError::UnsupportedVersion, + verify_handshake(&req).err().unwrap() + ); + + let req = TestRequest::default() + .header( + header::UPGRADE, + header::HeaderValue::from_static("websocket"), + ) + .header( + header::CONNECTION, + header::HeaderValue::from_static("upgrade"), + ) + .header( + header::SEC_WEBSOCKET_VERSION, + header::HeaderValue::from_static("13"), + ) + .finish(); + assert_eq!( + HandshakeError::BadWebsocketKey, + verify_handshake(&req).err().unwrap() + ); + + let req = TestRequest::default() + .header( + header::UPGRADE, + header::HeaderValue::from_static("websocket"), + ) + .header( + header::CONNECTION, + header::HeaderValue::from_static("upgrade"), + ) + .header( + header::SEC_WEBSOCKET_VERSION, + header::HeaderValue::from_static("13"), + ) + .header( + header::SEC_WEBSOCKET_KEY, + header::HeaderValue::from_static("13"), + ) + .finish(); + assert_eq!( + StatusCode::SWITCHING_PROTOCOLS, + handshake_response(&req).finish().status() + ); + } + + #[test] + fn test_wserror_http_response() { + let resp: Response = HandshakeError::GetMethodRequired.error_response(); + assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); + let resp: Response = HandshakeError::NoWebsocketUpgrade.error_response(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + let resp: Response = HandshakeError::NoConnectionUpgrade.error_response(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + let resp: Response = HandshakeError::NoVersionHeader.error_response(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + let resp: Response = HandshakeError::UnsupportedVersion.error_response(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + let resp: Response = HandshakeError::BadWebsocketKey.error_response(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + } +} diff --git a/actix-http/src/ws/proto.rs b/actix-http/src/ws/proto.rs new file mode 100644 index 000000000..eef874741 --- /dev/null +++ b/actix-http/src/ws/proto.rs @@ -0,0 +1,318 @@ +use base64; +use sha1; +use std::convert::{From, Into}; +use std::fmt; + +use self::OpCode::*; +/// Operation codes as part of rfc6455. +#[derive(Debug, Eq, PartialEq, Clone, Copy)] +pub enum OpCode { + /// Indicates a continuation frame of a fragmented message. + Continue, + /// Indicates a text data frame. + Text, + /// Indicates a binary data frame. + Binary, + /// Indicates a close control frame. + Close, + /// Indicates a ping control frame. + Ping, + /// Indicates a pong control frame. + Pong, + /// Indicates an invalid opcode was received. + Bad, +} + +impl fmt::Display for OpCode { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + Continue => write!(f, "CONTINUE"), + Text => write!(f, "TEXT"), + Binary => write!(f, "BINARY"), + Close => write!(f, "CLOSE"), + Ping => write!(f, "PING"), + Pong => write!(f, "PONG"), + Bad => write!(f, "BAD"), + } + } +} + +impl Into for OpCode { + fn into(self) -> u8 { + match self { + Continue => 0, + Text => 1, + Binary => 2, + Close => 8, + Ping => 9, + Pong => 10, + Bad => { + debug_assert!( + false, + "Attempted to convert invalid opcode to u8. This is a bug." + ); + 8 // if this somehow happens, a close frame will help us tear down quickly + } + } + } +} + +impl From for OpCode { + fn from(byte: u8) -> OpCode { + match byte { + 0 => Continue, + 1 => Text, + 2 => Binary, + 8 => Close, + 9 => Ping, + 10 => Pong, + _ => Bad, + } + } +} + +use self::CloseCode::*; +/// Status code used to indicate why an endpoint is closing the `WebSocket` +/// connection. +#[derive(Debug, Eq, PartialEq, Clone, Copy)] +pub enum CloseCode { + /// Indicates a normal closure, meaning that the purpose for + /// which the connection was established has been fulfilled. + Normal, + /// Indicates that an endpoint is "going away", such as a server + /// going down or a browser having navigated away from a page. + Away, + /// Indicates that an endpoint is terminating the connection due + /// to a protocol error. + Protocol, + /// Indicates that an endpoint is terminating the connection + /// because it has received a type of data it cannot accept (e.g., an + /// endpoint that understands only text data MAY send this if it + /// receives a binary message). + Unsupported, + /// Indicates an abnormal closure. If the abnormal closure was due to an + /// error, this close code will not be used. Instead, the `on_error` method + /// of the handler will be called with the error. However, if the connection + /// is simply dropped, without an error, this close code will be sent to the + /// handler. + Abnormal, + /// Indicates that an endpoint is terminating the connection + /// because it has received data within a message that was not + /// consistent with the type of the message (e.g., non-UTF-8 [RFC3629] + /// data within a text message). + Invalid, + /// Indicates that an endpoint is terminating the connection + /// because it has received a message that violates its policy. This + /// is a generic status code that can be returned when there is no + /// other more suitable status code (e.g., Unsupported or Size) or if there + /// is a need to hide specific details about the policy. + Policy, + /// Indicates that an endpoint is terminating the connection + /// because it has received a message that is too big for it to + /// process. + Size, + /// Indicates that an endpoint (client) is terminating the + /// connection because it has expected the server to negotiate one or + /// more extension, but the server didn't return them in the response + /// message of the WebSocket handshake. The list of extensions that + /// are needed should be given as the reason for closing. + /// Note that this status code is not used by the server, because it + /// can fail the WebSocket handshake instead. + Extension, + /// Indicates that a server is terminating the connection because + /// it encountered an unexpected condition that prevented it from + /// fulfilling the request. + Error, + /// Indicates that the server is restarting. A client may choose to + /// reconnect, and if it does, it should use a randomized delay of 5-30 + /// seconds between attempts. + Restart, + /// Indicates that the server is overloaded and the client should either + /// connect to a different IP (when multiple targets exist), or + /// reconnect to the same IP when a user has performed an action. + Again, + #[doc(hidden)] + Tls, + #[doc(hidden)] + Other(u16), +} + +impl Into for CloseCode { + fn into(self) -> u16 { + match self { + Normal => 1000, + Away => 1001, + Protocol => 1002, + Unsupported => 1003, + Abnormal => 1006, + Invalid => 1007, + Policy => 1008, + Size => 1009, + Extension => 1010, + Error => 1011, + Restart => 1012, + Again => 1013, + Tls => 1015, + Other(code) => code, + } + } +} + +impl From for CloseCode { + fn from(code: u16) -> CloseCode { + match code { + 1000 => Normal, + 1001 => Away, + 1002 => Protocol, + 1003 => Unsupported, + 1006 => Abnormal, + 1007 => Invalid, + 1008 => Policy, + 1009 => Size, + 1010 => Extension, + 1011 => Error, + 1012 => Restart, + 1013 => Again, + 1015 => Tls, + _ => Other(code), + } + } +} + +#[derive(Debug, Eq, PartialEq, Clone)] +/// Reason for closing the connection +pub struct CloseReason { + /// Exit code + pub code: CloseCode, + /// Optional description of the exit code + pub description: Option, +} + +impl From for CloseReason { + fn from(code: CloseCode) -> Self { + CloseReason { + code, + description: None, + } + } +} + +impl> From<(CloseCode, T)> for CloseReason { + fn from(info: (CloseCode, T)) -> Self { + CloseReason { + code: info.0, + description: Some(info.1.into()), + } + } +} + +static WS_GUID: &'static str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + +// TODO: hash is always same size, we dont need String +pub fn hash_key(key: &[u8]) -> String { + let mut hasher = sha1::Sha1::new(); + + hasher.update(key); + hasher.update(WS_GUID.as_bytes()); + + base64::encode(&hasher.digest().bytes()) +} + +#[cfg(test)] +mod test { + #![allow(unused_imports, unused_variables, dead_code)] + use super::*; + + macro_rules! opcode_into { + ($from:expr => $opcode:pat) => { + match OpCode::from($from) { + e @ $opcode => (), + e => unreachable!("{:?}", e), + } + }; + } + + macro_rules! opcode_from { + ($from:expr => $opcode:pat) => { + let res: u8 = $from.into(); + match res { + e @ $opcode => (), + e => unreachable!("{:?}", e), + } + }; + } + + #[test] + fn test_to_opcode() { + opcode_into!(0 => OpCode::Continue); + opcode_into!(1 => OpCode::Text); + opcode_into!(2 => OpCode::Binary); + opcode_into!(8 => OpCode::Close); + opcode_into!(9 => OpCode::Ping); + opcode_into!(10 => OpCode::Pong); + opcode_into!(99 => OpCode::Bad); + } + + #[test] + fn test_from_opcode() { + opcode_from!(OpCode::Continue => 0); + opcode_from!(OpCode::Text => 1); + opcode_from!(OpCode::Binary => 2); + opcode_from!(OpCode::Close => 8); + opcode_from!(OpCode::Ping => 9); + opcode_from!(OpCode::Pong => 10); + } + + #[test] + #[should_panic] + fn test_from_opcode_debug() { + opcode_from!(OpCode::Bad => 99); + } + + #[test] + fn test_from_opcode_display() { + assert_eq!(format!("{}", OpCode::Continue), "CONTINUE"); + assert_eq!(format!("{}", OpCode::Text), "TEXT"); + assert_eq!(format!("{}", OpCode::Binary), "BINARY"); + assert_eq!(format!("{}", OpCode::Close), "CLOSE"); + assert_eq!(format!("{}", OpCode::Ping), "PING"); + assert_eq!(format!("{}", OpCode::Pong), "PONG"); + assert_eq!(format!("{}", OpCode::Bad), "BAD"); + } + + #[test] + fn closecode_from_u16() { + assert_eq!(CloseCode::from(1000u16), CloseCode::Normal); + assert_eq!(CloseCode::from(1001u16), CloseCode::Away); + assert_eq!(CloseCode::from(1002u16), CloseCode::Protocol); + assert_eq!(CloseCode::from(1003u16), CloseCode::Unsupported); + assert_eq!(CloseCode::from(1006u16), CloseCode::Abnormal); + assert_eq!(CloseCode::from(1007u16), CloseCode::Invalid); + assert_eq!(CloseCode::from(1008u16), CloseCode::Policy); + assert_eq!(CloseCode::from(1009u16), CloseCode::Size); + assert_eq!(CloseCode::from(1010u16), CloseCode::Extension); + assert_eq!(CloseCode::from(1011u16), CloseCode::Error); + assert_eq!(CloseCode::from(1012u16), CloseCode::Restart); + assert_eq!(CloseCode::from(1013u16), CloseCode::Again); + assert_eq!(CloseCode::from(1015u16), CloseCode::Tls); + assert_eq!(CloseCode::from(2000u16), CloseCode::Other(2000)); + } + + #[test] + fn closecode_into_u16() { + assert_eq!(1000u16, Into::::into(CloseCode::Normal)); + assert_eq!(1001u16, Into::::into(CloseCode::Away)); + assert_eq!(1002u16, Into::::into(CloseCode::Protocol)); + assert_eq!(1003u16, Into::::into(CloseCode::Unsupported)); + assert_eq!(1006u16, Into::::into(CloseCode::Abnormal)); + assert_eq!(1007u16, Into::::into(CloseCode::Invalid)); + assert_eq!(1008u16, Into::::into(CloseCode::Policy)); + assert_eq!(1009u16, Into::::into(CloseCode::Size)); + assert_eq!(1010u16, Into::::into(CloseCode::Extension)); + assert_eq!(1011u16, Into::::into(CloseCode::Error)); + assert_eq!(1012u16, Into::::into(CloseCode::Restart)); + assert_eq!(1013u16, Into::::into(CloseCode::Again)); + assert_eq!(1015u16, Into::::into(CloseCode::Tls)); + assert_eq!(2000u16, Into::::into(CloseCode::Other(2000))); + } +} diff --git a/actix-http/src/ws/service.rs b/actix-http/src/ws/service.rs new file mode 100644 index 000000000..f3b066053 --- /dev/null +++ b/actix-http/src/ws/service.rs @@ -0,0 +1,52 @@ +use std::marker::PhantomData; + +use actix_codec::Framed; +use actix_service::{NewService, Service}; +use futures::future::{ok, FutureResult}; +use futures::{Async, IntoFuture, Poll}; + +use crate::h1::Codec; +use crate::request::Request; + +use super::{verify_handshake, HandshakeError}; + +pub struct VerifyWebSockets { + _t: PhantomData, +} + +impl Default for VerifyWebSockets { + fn default() -> Self { + VerifyWebSockets { _t: PhantomData } + } +} + +impl NewService for VerifyWebSockets { + type Request = (Request, Framed); + type Response = (Request, Framed); + type Error = (HandshakeError, Framed); + type InitError = (); + type Service = VerifyWebSockets; + type Future = FutureResult; + + fn new_service(&self, _: &()) -> Self::Future { + ok(VerifyWebSockets { _t: PhantomData }) + } +} + +impl Service for VerifyWebSockets { + type Request = (Request, Framed); + type Response = (Request, Framed); + type Error = (HandshakeError, Framed); + type Future = FutureResult; + + fn poll_ready(&mut self) -> Poll<(), Self::Error> { + Ok(Async::Ready(())) + } + + fn call(&mut self, (req, framed): (Request, Framed)) -> Self::Future { + match verify_handshake(&req) { + Err(e) => Err((e, framed)).into_future(), + Ok(_) => Ok((req, framed)).into_future(), + } + } +} diff --git a/actix-http/src/ws/transport.rs b/actix-http/src/ws/transport.rs new file mode 100644 index 000000000..da7782be5 --- /dev/null +++ b/actix-http/src/ws/transport.rs @@ -0,0 +1,49 @@ +use actix_codec::{AsyncRead, AsyncWrite, Framed}; +use actix_service::{IntoService, Service}; +use actix_utils::framed::{FramedTransport, FramedTransportError}; +use futures::{Future, Poll}; + +use super::{Codec, Frame, Message}; + +pub struct Transport +where + S: Service + 'static, + T: AsyncRead + AsyncWrite, +{ + inner: FramedTransport, +} + +impl Transport +where + T: AsyncRead + AsyncWrite, + S: Service, + S::Future: 'static, + S::Error: 'static, +{ + pub fn new>(io: T, service: F) -> Self { + Transport { + inner: FramedTransport::new(Framed::new(io, Codec::new()), service), + } + } + + pub fn with>(framed: Framed, service: F) -> Self { + Transport { + inner: FramedTransport::new(framed, service), + } + } +} + +impl Future for Transport +where + T: AsyncRead + AsyncWrite, + S: Service, + S::Future: 'static, + S::Error: 'static, +{ + type Item = (); + type Error = FramedTransportError; + + fn poll(&mut self) -> Poll { + self.inner.poll() + } +} diff --git a/actix-http/test-server/Cargo.toml b/actix-http/test-server/Cargo.toml new file mode 100644 index 000000000..313bc55cf --- /dev/null +++ b/actix-http/test-server/Cargo.toml @@ -0,0 +1,60 @@ +[package] +name = "actix-http-test" +version = "0.1.0" +authors = ["Nikolay Kim "] +description = "Actix http" +readme = "README.md" +keywords = ["http", "web", "framework", "async", "futures"] +homepage = "https://actix.rs" +repository = "https://github.com/actix/actix-web.git" +documentation = "https://actix.rs/api/actix-web/stable/actix_web/" +categories = ["network-programming", "asynchronous", + "web-programming::http-server", + "web-programming::websocket"] +license = "MIT/Apache-2.0" +exclude = [".gitignore", ".travis.yml", ".cargo/config", "appveyor.yml"] +edition = "2018" + +[package.metadata.docs.rs] +features = ["session"] + +[lib] +name = "actix_http_test" +path = "src/lib.rs" + +[features] +default = ["session"] + +# sessions feature, session require "ring" crate and c compiler +session = ["cookie/secure"] + +# openssl +ssl = ["openssl", "actix-http/ssl", "actix-server/ssl", "awc/ssl"] + +[dependencies] +actix-codec = "0.1.1" +actix-rt = "0.2.1" +actix-http = { path=".." } +actix-service = "0.3.4" +actix-server = "0.4.0" +actix-utils = "0.3.4" +awc = { git = "https://github.com/actix/actix-web.git" } + +base64 = "0.10" +bytes = "0.4" +cookie = { version="0.11", features=["percent-encode"] } +futures = "0.1" +http = "0.1.8" +log = "0.4" +env_logger = "0.6" +net2 = "0.2" +serde = "1.0" +serde_json = "1.0" +sha1 = "0.6" +slab = "0.4" +serde_urlencoded = "0.5.3" +time = "0.1" +tokio-tcp = "0.1" +tokio-timer = "0.2" + +openssl = { version="0.10", optional = true } diff --git a/actix-http/test-server/src/lib.rs b/actix-http/test-server/src/lib.rs new file mode 100644 index 000000000..77329e700 --- /dev/null +++ b/actix-http/test-server/src/lib.rs @@ -0,0 +1,222 @@ +//! Various helpers for Actix applications to use during testing. +use std::sync::mpsc; +use std::{net, thread, time}; + +use actix_codec::{AsyncRead, AsyncWrite, Framed}; +use actix_http::client::Connector; +use actix_http::ws; +use actix_rt::{Runtime, System}; +use actix_server::{Server, StreamServiceFactory}; +use actix_service::Service; +use awc::{Client, ClientRequest}; +use futures::future::{lazy, Future}; +use http::Method; +use net2::TcpBuilder; + +/// The `TestServer` type. +/// +/// `TestServer` is very simple test server that simplify process of writing +/// integration tests cases for actix web applications. +/// +/// # Examples +/// +/// ```rust +/// # extern crate actix_web; +/// # use actix_web::*; +/// # +/// # fn my_handler(req: &HttpRequest) -> HttpResponse { +/// # HttpResponse::Ok().into() +/// # } +/// # +/// # fn main() { +/// use actix_web::test::TestServer; +/// +/// let mut srv = TestServer::new(|app| app.handler(my_handler)); +/// +/// let req = srv.get().finish().unwrap(); +/// let response = srv.execute(req.send()).unwrap(); +/// assert!(response.status().is_success()); +/// # } +/// ``` +pub struct TestServer; + +/// Test server controller +pub struct TestServerRuntime { + addr: net::SocketAddr, + rt: Runtime, + client: Client, +} + +impl TestServer { + /// Start new test server with application factory + pub fn new(factory: F) -> TestServerRuntime { + let (tx, rx) = mpsc::channel(); + + // run server in separate thread + thread::spawn(move || { + let sys = System::new("actix-test-server"); + let tcp = net::TcpListener::bind("127.0.0.1:0").unwrap(); + let local_addr = tcp.local_addr().unwrap(); + + Server::build() + .listen("test", tcp, factory)? + .workers(1) + .disable_signals() + .start(); + + tx.send((System::current(), local_addr)).unwrap(); + sys.run() + }); + + let (system, addr) = rx.recv().unwrap(); + let mut rt = Runtime::new().unwrap(); + + let client = rt + .block_on(lazy(move || { + let connector = { + #[cfg(feature = "ssl")] + { + use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; + + let mut builder = + SslConnector::builder(SslMethod::tls()).unwrap(); + builder.set_verify(SslVerifyMode::NONE); + let _ = builder.set_alpn_protos(b"\x02h2\x08http/1.1").map_err( + |e| log::error!("Can not set alpn protocol: {:?}", e), + ); + Connector::new() + .timeout(time::Duration::from_millis(500)) + .ssl(builder.build()) + .service() + } + #[cfg(not(feature = "ssl"))] + { + Connector::new() + .timeout(time::Duration::from_millis(500)) + .service() + } + }; + + Ok::(Client::build().connector(connector).finish()) + })) + .unwrap(); + System::set_current(system); + TestServerRuntime { addr, rt, client } + } + + /// Get firat available unused address + pub fn unused_addr() -> net::SocketAddr { + let addr: net::SocketAddr = "127.0.0.1:0".parse().unwrap(); + let socket = TcpBuilder::new_v4().unwrap(); + socket.bind(&addr).unwrap(); + socket.reuse_address(true).unwrap(); + let tcp = socket.to_tcp_listener().unwrap(); + tcp.local_addr().unwrap() + } +} + +impl TestServerRuntime { + /// Execute future on current core + pub fn block_on(&mut self, fut: F) -> Result + where + F: Future, + { + self.rt.block_on(fut) + } + + /// Execute function on current core + pub fn execute(&mut self, fut: F) -> R + where + F: FnOnce() -> R, + { + self.rt.block_on(lazy(|| Ok::<_, ()>(fut()))).unwrap() + } + + /// Construct test server url + pub fn addr(&self) -> net::SocketAddr { + self.addr + } + + /// Construct test server url + pub fn url(&self, uri: &str) -> String { + if uri.starts_with('/') { + format!("http://127.0.0.1:{}{}", self.addr.port(), uri) + } else { + format!("http://127.0.0.1:{}/{}", self.addr.port(), uri) + } + } + + /// Construct test https server url + pub fn surl(&self, uri: &str) -> String { + if uri.starts_with('/') { + format!("https://127.0.0.1:{}{}", self.addr.port(), uri) + } else { + format!("https://127.0.0.1:{}/{}", self.addr.port(), uri) + } + } + + /// Create `GET` request + pub fn get(&self) -> ClientRequest { + self.client.get(self.url("/").as_str()) + } + + /// Create https `GET` request + pub fn sget(&self) -> ClientRequest { + self.client.get(self.surl("/").as_str()) + } + + /// Create `POST` request + pub fn post(&self) -> ClientRequest { + self.client.post(self.url("/").as_str()) + } + + /// Create https `POST` request + pub fn spost(&self) -> ClientRequest { + self.client.post(self.surl("/").as_str()) + } + + /// Create `HEAD` request + pub fn head(&self) -> ClientRequest { + self.client.head(self.url("/").as_str()) + } + + /// Create https `HEAD` request + pub fn shead(&self) -> ClientRequest { + self.client.head(self.surl("/").as_str()) + } + + /// Connect to test http server + pub fn request>(&self, method: Method, path: S) -> ClientRequest { + self.client.request(method, path.as_ref()) + } + + /// Stop http server + fn stop(&mut self) { + System::current().stop(); + } +} + +impl TestServerRuntime { + /// Connect to websocket server at a given path + pub fn ws_at( + &mut self, + path: &str, + ) -> Result, ws::ClientError> { + let url = self.url(path); + self.rt + .block_on(lazy(|| ws::Client::default().call(ws::Connect::new(url)))) + } + + /// Connect to a websocket server + pub fn ws( + &mut self, + ) -> Result, ws::ClientError> { + self.ws_at("/") + } +} + +impl Drop for TestServerRuntime { + fn drop(&mut self) { + self.stop() + } +} diff --git a/actix-http/tests/cert.pem b/actix-http/tests/cert.pem new file mode 100644 index 000000000..5e195d98d --- /dev/null +++ b/actix-http/tests/cert.pem @@ -0,0 +1,16 @@ +-----BEGIN CERTIFICATE----- +MIICljCCAX4CCQDFdWu66640QjANBgkqhkiG9w0BAQsFADANMQswCQYDVQQGEwJ1 +czAeFw0xOTAyMDQyMzEyNTBaFw0yMDAyMDQyMzEyNTBaMA0xCzAJBgNVBAYTAnVz +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAzZUXMnS5X8HWxTvHAc82 +Q2d32fiPQGtD+fp3OV90l6RC9jgMdH4yTVUgX5mYYcW0k89RaP8g61H6b76F9gcd +yZ1idqKI1AU9aeBUPV8wkrouhR/6Omv8fA7yr9tVmNo53jPN7WyKoBoU0r7Yj9Ez +g3qjv/808Jlgby3EhduruyyfdvSt5ZFXnOz2D3SF9DS4yrM2jSw4ZTuoVMfZ8vZe +FVzLo/+sV8qokU6wBTEOAmZQ7e/zZV4qAoH2Z3Vj/uD1Zr/MXYyh81RdXpDqIXwV +Z29LEOa2eTGFEdvfG+tdvvuIvSdF3+WbLrwn2ECfwJ8zmKyTauPRV4pj7ks+wkBI +EQIDAQABMA0GCSqGSIb3DQEBCwUAA4IBAQB6dmuWBOpFfDdu0mdsDb8XnJY1svjH +4kbztXhjQJ/WuhCUIwvXFyz9dqQCq+TbJUbUEzZJEfaq1uaI3iB5wd35ArSoAGJA +k0lonzyeSM+cmNOe/5BPqWhd1qPwbsfgMoCCkZUoTT5Rvw6yt00XIqZzMqrsvRBX +hAcUW3zBtFQNP6aQqsMdn4ClZE0WHf+LzWy2NQh+Sf46tSYBHELfdUawgR789PB4 +/gNjAeklq06JmE/3gELijwaijVIuUsMC9ua//ITk4YIFpqanPtka+7BpfTegPGNs +HCj1g7Jot97oQMuvDOJeso91aiSA+gutepCClZICT8LxNRkY3ZlXYp92 +-----END CERTIFICATE----- diff --git a/actix-http/tests/identity.pfx b/actix-http/tests/identity.pfx new file mode 100644 index 000000000..946e3b8b8 Binary files /dev/null and b/actix-http/tests/identity.pfx differ diff --git a/actix-http/tests/key.pem b/actix-http/tests/key.pem new file mode 100644 index 000000000..50ded0ce0 --- /dev/null +++ b/actix-http/tests/key.pem @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDNlRcydLlfwdbF +O8cBzzZDZ3fZ+I9Aa0P5+nc5X3SXpEL2OAx0fjJNVSBfmZhhxbSTz1Fo/yDrUfpv +voX2Bx3JnWJ2oojUBT1p4FQ9XzCSui6FH/o6a/x8DvKv21WY2jneM83tbIqgGhTS +vtiP0TODeqO//zTwmWBvLcSF26u7LJ929K3lkVec7PYPdIX0NLjKszaNLDhlO6hU +x9ny9l4VXMuj/6xXyqiRTrAFMQ4CZlDt7/NlXioCgfZndWP+4PVmv8xdjKHzVF1e +kOohfBVnb0sQ5rZ5MYUR298b612++4i9J0Xf5ZsuvCfYQJ/AnzOYrJNq49FXimPu +Sz7CQEgRAgMBAAECggEBALC547EaKmko5wmyM4dYq9sRzTPxuqO0EkGIkIkfh8j8 +ChxDXmGeQnu8HBJSpW4XWP5fkCpkd9YTKOh6rgorX+37f7NgUaOBxaOIlqITfFwF +9Qu3y5IBVpEHAJUwRcsaffiILBRX5GtxQElSijRHsLLr8GySZN4X25B3laNEjcJe +NWJrDaxOn0m0MMGRvBpM8PaZu1Mn9NWxt04b/fteVLdN4TAcuY9TgvVZBq92S2FM +qvZcnJCQckNOuMOptVdP45qPkerKUohpOcqBfIiWFaalC378jE3Dm68p7slt3R6y +I1wVqCI4+MZfM3CtKcYJV0fdqklJCvXORvRiT8OZKakCgYEA5YnhgXOu4CO4DR1T +Lacv716DPyHeKVa6TbHhUhWw4bLwNLUsEL98jeU9SZ6VH8enBlDm5pCsp2i//t9n +8hoykN4L0rS4EyAGENouTRkLhtHfjTAKTKDK8cNvEaS8NOBJWrI0DTiHtFbCRBvI +zRx5VhrB5H4DDbqn7QV9g+GBKvMCgYEA5Ug3bN0RNUi3KDoIRcnWc06HsX307su7 +tB4cGqXJqVOJCrkk5sefGF502+W8m3Ldjaakr+Q9BoOdZX6boZnFtVetT8Hyzk1C +Rkiyz3GcwovOkQK//UmljsuRjgHF+PuQGX5ol4YlJtXU21k5bCsi1Tmyp7IufiGV +AQRMVZVbeesCgYA/QBZGwKTgmJcf7gO8ocRAto999wwr4f0maazIHLgICXHNZFsH +JmzhANk5jxxSjIaG5AYsZJNe8ittxQv0l6l1Z+pkHm5Wvs1NGYIGtq8JcI2kbyd3 +ZBtoMU1K1FUUUPWFq3NSbVBfrkSL1ggoFP+ObYMePmcDAntBgfDLRXl9ZwKBgQCt +/dh5l2UIn27Gawt+EkXX6L8WVTQ6xoZhj/vZyPe4tDip14gGTXQQ5RUfDj7LZCZ2 +6P/OrpAU0mnt7F8kCfI7xBY0EUU1gvGJLn/q5heElt2hs4mIJ4woSZjiP7xBTn2y +qveqDNVCnEBUWGg4Cp/7WTaXBaM8ejV9uQpIY/gwEwKBgQCCYnd9fD8L4nGyOLoD +eUzMV7G8TZfinlxCNMVXfpn4Z8OaYHOk5NiujHK55w4ghx06NQw038qhnX0ogbjU +caWOwCIbrYgx2fwYuOZbJFXdjWlpjIK3RFOcbNCNgCRLT6Lgz4uZYZ9RVftADvMi +zR1QsLWnIvARbTtOPfZqizT2gQ== +-----END PRIVATE KEY----- diff --git a/actix-http/tests/test.binary b/actix-http/tests/test.binary new file mode 100644 index 000000000..ef8ff0245 --- /dev/null +++ b/actix-http/tests/test.binary @@ -0,0 +1 @@ +TǑɂV2vI\R˙evD:藽RVYp;Gp!2C. pA !ߦx j+UcXc%;"yAI \ No newline at end of file diff --git a/actix-http/tests/test.png b/actix-http/tests/test.png new file mode 100644 index 000000000..6b7cdc0b8 Binary files /dev/null and b/actix-http/tests/test.png differ diff --git a/actix-http/tests/test_client.rs b/actix-http/tests/test_client.rs new file mode 100644 index 000000000..1ca7437d6 --- /dev/null +++ b/actix-http/tests/test_client.rs @@ -0,0 +1,98 @@ +use actix_service::NewService; +use bytes::{Bytes, BytesMut}; +use futures::future::{self, ok}; +use futures::{Future, Stream}; + +use actix_http::{ + error::PayloadError, http, HttpMessage, HttpService, Request, Response, +}; +use actix_http_test::TestServer; + +const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World"; + +fn load_body(stream: S) -> impl Future +where + S: Stream, +{ + stream.fold(BytesMut::new(), move |mut body, chunk| { + body.extend_from_slice(&chunk); + Ok::<_, PayloadError>(body) + }) +} + +#[test] +fn test_h1_v2() { + env_logger::init(); + let mut srv = TestServer::new(move || { + HttpService::build() + .finish(|_| future::ok::<_, ()>(Response::Ok().body(STR))) + .map(|_| ()) + }); + let response = srv.block_on(srv.get().send()).unwrap(); + assert!(response.status().is_success()); + + let request = srv.get().header("x-test", "111").send(); + let mut response = srv.block_on(request).unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = srv.block_on(load_body(response.take_payload())).unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); + + let mut response = srv.block_on(srv.post().send()).unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = srv.block_on(load_body(response.take_payload())).unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); +} + +#[test] +fn test_connection_close() { + let mut srv = TestServer::new(move || { + HttpService::build() + .finish(|_| ok::<_, ()>(Response::Ok().body(STR))) + .map(|_| ()) + }); + let response = srv.block_on(srv.get().close_connection().send()).unwrap(); + assert!(response.status().is_success()); +} + +#[test] +fn test_with_query_parameter() { + let mut srv = TestServer::new(move || { + HttpService::build() + .finish(|req: Request| { + if req.uri().query().unwrap().contains("qp=") { + ok::<_, ()>(Response::Ok().finish()) + } else { + ok::<_, ()>(Response::BadRequest().finish()) + } + }) + .map(|_| ()) + }); + + let request = srv.request(http::Method::GET, srv.url("/?qp=5")).send(); + let response = srv.block_on(request).unwrap(); + assert!(response.status().is_success()); +} diff --git a/actix-http/tests/test_server.rs b/actix-http/tests/test_server.rs new file mode 100644 index 000000000..5777c5691 --- /dev/null +++ b/actix-http/tests/test_server.rs @@ -0,0 +1,933 @@ +use std::io::{Read, Write}; +use std::time::Duration; +use std::{net, thread}; + +use actix_codec::{AsyncRead, AsyncWrite}; +use actix_http_test::TestServer; +use actix_server_config::ServerConfig; +use actix_service::{fn_cfg_factory, NewService}; +use bytes::{Bytes, BytesMut}; +use futures::future::{self, ok, Future}; +use futures::stream::{once, Stream}; + +use actix_http::body::Body; +use actix_http::error::PayloadError; +use actix_http::{ + body, error, http, http::header, Error, HttpMessage as HttpMessage2, HttpService, + KeepAlive, Request, Response, +}; + +fn load_body(stream: S) -> impl Future +where + S: Stream, +{ + stream.fold(BytesMut::new(), move |mut body, chunk| { + body.extend_from_slice(&chunk); + Ok::<_, PayloadError>(body) + }) +} + +#[test] +fn test_h1() { + let mut srv = TestServer::new(|| { + HttpService::build() + .keep_alive(KeepAlive::Disabled) + .client_timeout(1000) + .client_disconnect(1000) + .h1(|_| future::ok::<_, ()>(Response::Ok().finish())) + }); + + let response = srv.block_on(srv.get().send()).unwrap(); + assert!(response.status().is_success()); +} + +#[test] +fn test_h1_2() { + let mut srv = TestServer::new(|| { + HttpService::build() + .keep_alive(KeepAlive::Disabled) + .client_timeout(1000) + .client_disconnect(1000) + .finish(|req: Request| { + assert_eq!(req.version(), http::Version::HTTP_11); + future::ok::<_, ()>(Response::Ok().finish()) + }) + .map(|_| ()) + }); + + let response = srv.block_on(srv.get().send()).unwrap(); + assert!(response.status().is_success()); +} + +#[cfg(feature = "ssl")] +fn ssl_acceptor( +) -> std::io::Result> { + use openssl::ssl::{SslAcceptor, SslFiletype, SslMethod}; + // load ssl keys + let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); + builder + .set_private_key_file("tests/key.pem", SslFiletype::PEM) + .unwrap(); + builder + .set_certificate_chain_file("tests/cert.pem") + .unwrap(); + builder.set_alpn_select_callback(|_, protos| { + const H2: &[u8] = b"\x02h2"; + if protos.windows(3).any(|window| window == H2) { + Ok(b"h2") + } else { + Err(openssl::ssl::AlpnError::NOACK) + } + }); + builder.set_alpn_protos(b"\x02h2")?; + Ok(actix_server::ssl::OpensslAcceptor::new(builder.build())) +} + +#[cfg(feature = "ssl")] +#[test] +fn test_h2() -> std::io::Result<()> { + let openssl = ssl_acceptor()?; + let mut srv = TestServer::new(move || { + openssl + .clone() + .map_err(|e| println!("Openssl error: {}", e)) + .and_then( + HttpService::build() + .h2(|_| future::ok::<_, Error>(Response::Ok().finish())) + .map_err(|_| ()), + ) + }); + + let response = srv.block_on(srv.sget().send()).unwrap(); + assert!(response.status().is_success()); + Ok(()) +} + +#[cfg(feature = "ssl")] +#[test] +fn test_h2_1() -> std::io::Result<()> { + let openssl = ssl_acceptor()?; + let mut srv = TestServer::new(move || { + openssl + .clone() + .map_err(|e| println!("Openssl error: {}", e)) + .and_then( + HttpService::build() + .finish(|req: Request| { + assert_eq!(req.version(), http::Version::HTTP_2); + future::ok::<_, Error>(Response::Ok().finish()) + }) + .map_err(|_| ()), + ) + }); + + let response = srv.block_on(srv.sget().send()).unwrap(); + assert!(response.status().is_success()); + Ok(()) +} + +#[cfg(feature = "ssl")] +#[test] +fn test_h2_body() -> std::io::Result<()> { + let data = "HELLOWORLD".to_owned().repeat(64 * 1024); + let openssl = ssl_acceptor()?; + let mut srv = TestServer::new(move || { + openssl + .clone() + .map_err(|e| println!("Openssl error: {}", e)) + .and_then( + HttpService::build() + .h2(|mut req: Request<_>| { + load_body(req.take_payload()) + .and_then(|body| Ok(Response::Ok().body(body))) + }) + .map_err(|_| ()), + ) + }); + + let mut response = srv.block_on(srv.sget().send_body(data.clone())).unwrap(); + assert!(response.status().is_success()); + + let body = srv.block_on(load_body(response.take_payload())).unwrap(); + assert_eq!(&body, data.as_bytes()); + Ok(()) +} + +#[test] +fn test_slow_request() { + let srv = TestServer::new(|| { + HttpService::build() + .client_timeout(100) + .finish(|_| future::ok::<_, ()>(Response::Ok().finish())) + }); + + let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); + let _ = stream.write_all(b"GET /test/tests/test HTTP/1.1\r\n"); + let mut data = String::new(); + let _ = stream.read_to_string(&mut data); + assert!(data.starts_with("HTTP/1.1 408 Request Timeout")); +} + +#[test] +fn test_http1_malformed_request() { + let srv = TestServer::new(|| { + HttpService::build().h1(|_| future::ok::<_, ()>(Response::Ok().finish())) + }); + + let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); + let _ = stream.write_all(b"GET /test/tests/test HTTP1.1\r\n"); + let mut data = String::new(); + let _ = stream.read_to_string(&mut data); + assert!(data.starts_with("HTTP/1.1 400 Bad Request")); +} + +#[test] +fn test_http1_keepalive() { + let srv = TestServer::new(|| { + HttpService::build().h1(|_| future::ok::<_, ()>(Response::Ok().finish())) + }); + + let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); + let _ = stream.write_all(b"GET /test/tests/test HTTP/1.1\r\n\r\n"); + let mut data = vec![0; 1024]; + let _ = stream.read(&mut data); + assert_eq!(&data[..17], b"HTTP/1.1 200 OK\r\n"); + + let _ = stream.write_all(b"GET /test/tests/test HTTP/1.1\r\n\r\n"); + let mut data = vec![0; 1024]; + let _ = stream.read(&mut data); + assert_eq!(&data[..17], b"HTTP/1.1 200 OK\r\n"); +} + +#[test] +fn test_http1_keepalive_timeout() { + let srv = TestServer::new(|| { + HttpService::build() + .keep_alive(1) + .h1(|_| future::ok::<_, ()>(Response::Ok().finish())) + }); + + let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); + let _ = stream.write_all(b"GET /test/tests/test HTTP/1.1\r\n\r\n"); + let mut data = vec![0; 1024]; + let _ = stream.read(&mut data); + assert_eq!(&data[..17], b"HTTP/1.1 200 OK\r\n"); + thread::sleep(Duration::from_millis(1100)); + + let mut data = vec![0; 1024]; + let res = stream.read(&mut data).unwrap(); + assert_eq!(res, 0); +} + +#[test] +fn test_http1_keepalive_close() { + let srv = TestServer::new(|| { + HttpService::build().h1(|_| future::ok::<_, ()>(Response::Ok().finish())) + }); + + let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); + let _ = + stream.write_all(b"GET /test/tests/test HTTP/1.1\r\nconnection: close\r\n\r\n"); + let mut data = vec![0; 1024]; + let _ = stream.read(&mut data); + assert_eq!(&data[..17], b"HTTP/1.1 200 OK\r\n"); + + let mut data = vec![0; 1024]; + let res = stream.read(&mut data).unwrap(); + assert_eq!(res, 0); +} + +#[test] +fn test_http10_keepalive_default_close() { + let srv = TestServer::new(|| { + HttpService::build().h1(|_| future::ok::<_, ()>(Response::Ok().finish())) + }); + + let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); + let _ = stream.write_all(b"GET /test/tests/test HTTP/1.0\r\n\r\n"); + let mut data = vec![0; 1024]; + let _ = stream.read(&mut data); + assert_eq!(&data[..17], b"HTTP/1.0 200 OK\r\n"); + + let mut data = vec![0; 1024]; + let res = stream.read(&mut data).unwrap(); + assert_eq!(res, 0); +} + +#[test] +fn test_http10_keepalive() { + let srv = TestServer::new(|| { + HttpService::build().h1(|_| future::ok::<_, ()>(Response::Ok().finish())) + }); + + let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); + let _ = stream + .write_all(b"GET /test/tests/test HTTP/1.0\r\nconnection: keep-alive\r\n\r\n"); + let mut data = vec![0; 1024]; + let _ = stream.read(&mut data); + assert_eq!(&data[..17], b"HTTP/1.0 200 OK\r\n"); + + let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); + let _ = stream.write_all(b"GET /test/tests/test HTTP/1.0\r\n\r\n"); + let mut data = vec![0; 1024]; + let _ = stream.read(&mut data); + assert_eq!(&data[..17], b"HTTP/1.0 200 OK\r\n"); + + let mut data = vec![0; 1024]; + let res = stream.read(&mut data).unwrap(); + assert_eq!(res, 0); +} + +#[test] +fn test_http1_keepalive_disabled() { + let srv = TestServer::new(|| { + HttpService::build() + .keep_alive(KeepAlive::Disabled) + .h1(|_| future::ok::<_, ()>(Response::Ok().finish())) + }); + + let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); + let _ = stream.write_all(b"GET /test/tests/test HTTP/1.1\r\n\r\n"); + let mut data = vec![0; 1024]; + let _ = stream.read(&mut data); + assert_eq!(&data[..17], b"HTTP/1.1 200 OK\r\n"); + + let mut data = vec![0; 1024]; + let res = stream.read(&mut data).unwrap(); + assert_eq!(res, 0); +} + +#[test] +fn test_content_length() { + use actix_http::http::{ + header::{HeaderName, HeaderValue}, + StatusCode, + }; + + let mut srv = TestServer::new(|| { + HttpService::build().h1(|req: Request| { + let indx: usize = req.uri().path()[1..].parse().unwrap(); + let statuses = [ + StatusCode::NO_CONTENT, + StatusCode::CONTINUE, + StatusCode::SWITCHING_PROTOCOLS, + StatusCode::PROCESSING, + StatusCode::OK, + StatusCode::NOT_FOUND, + ]; + future::ok::<_, ()>(Response::new(statuses[indx])) + }) + }); + + let header = HeaderName::from_static("content-length"); + let value = HeaderValue::from_static("0"); + + { + for i in 0..4 { + let req = srv + .request(http::Method::GET, srv.url(&format!("/{}", i))) + .send(); + let response = srv.block_on(req).unwrap(); + assert_eq!(response.headers().get(&header), None); + + let req = srv + .request(http::Method::HEAD, srv.url(&format!("/{}", i))) + .send(); + let response = srv.block_on(req).unwrap(); + assert_eq!(response.headers().get(&header), None); + } + + for i in 4..6 { + let req = srv + .request(http::Method::GET, srv.url(&format!("/{}", i))) + .send(); + let response = srv.block_on(req).unwrap(); + assert_eq!(response.headers().get(&header), Some(&value)); + } + } +} + +#[test] +fn test_h2_content_length() { + use actix_http::http::{ + header::{HeaderName, HeaderValue}, + StatusCode, + }; + let openssl = ssl_acceptor().unwrap(); + + let mut srv = TestServer::new(move || { + openssl + .clone() + .map_err(|e| println!("Openssl error: {}", e)) + .and_then( + HttpService::build() + .h2(|req: Request| { + let indx: usize = req.uri().path()[1..].parse().unwrap(); + let statuses = [ + StatusCode::NO_CONTENT, + StatusCode::CONTINUE, + StatusCode::SWITCHING_PROTOCOLS, + StatusCode::PROCESSING, + StatusCode::OK, + StatusCode::NOT_FOUND, + ]; + future::ok::<_, ()>(Response::new(statuses[indx])) + }) + .map_err(|_| ()), + ) + }); + + let header = HeaderName::from_static("content-length"); + let value = HeaderValue::from_static("0"); + + { + for i in 0..4 { + let req = srv + .request(http::Method::GET, srv.surl(&format!("/{}", i))) + .send(); + let response = srv.block_on(req).unwrap(); + assert_eq!(response.headers().get(&header), None); + + let req = srv + .request(http::Method::HEAD, srv.surl(&format!("/{}", i))) + .send(); + let response = srv.block_on(req).unwrap(); + assert_eq!(response.headers().get(&header), None); + } + + for i in 4..6 { + let req = srv + .request(http::Method::GET, srv.surl(&format!("/{}", i))) + .send(); + let response = srv.block_on(req).unwrap(); + assert_eq!(response.headers().get(&header), Some(&value)); + } + } +} + +#[test] +fn test_h1_headers() { + let data = STR.repeat(10); + let data2 = data.clone(); + + let mut srv = TestServer::new(move || { + let data = data.clone(); + HttpService::build().h1(move |_| { + let mut builder = Response::Ok(); + for idx in 0..90 { + builder.header( + format!("X-TEST-{}", idx).as_str(), + "TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST ", + ); + } + future::ok::<_, ()>(builder.body(data.clone())) + }) + }); + + let mut response = srv.block_on(srv.get().send()).unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = srv.block_on(load_body(response.take_payload())).unwrap(); + assert_eq!(bytes, Bytes::from(data2)); +} + +#[test] +fn test_h2_headers() { + let data = STR.repeat(10); + let data2 = data.clone(); + let openssl = ssl_acceptor().unwrap(); + + let mut srv = TestServer::new(move || { + let data = data.clone(); + openssl + .clone() + .map_err(|e| println!("Openssl error: {}", e)) + .and_then( + HttpService::build().h2(move |_| { + let mut builder = Response::Ok(); + for idx in 0..90 { + builder.header( + format!("X-TEST-{}", idx).as_str(), + "TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ + TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST ", + ); + } + future::ok::<_, ()>(builder.body(data.clone())) + }).map_err(|_| ())) + }); + + let mut response = srv.block_on(srv.sget().send()).unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = srv.block_on(load_body(response.take_payload())).unwrap(); + assert_eq!(bytes, Bytes::from(data2)); +} + +const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World"; + +#[test] +fn test_h1_body() { + let mut srv = TestServer::new(|| { + HttpService::build().h1(|_| future::ok::<_, ()>(Response::Ok().body(STR))) + }); + + let mut response = srv.block_on(srv.get().send()).unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = srv.block_on(load_body(response.take_payload())).unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); +} + +#[test] +fn test_h2_body2() { + let openssl = ssl_acceptor().unwrap(); + let mut srv = TestServer::new(move || { + openssl + .clone() + .map_err(|e| println!("Openssl error: {}", e)) + .and_then( + HttpService::build() + .h2(|_| future::ok::<_, ()>(Response::Ok().body(STR))) + .map_err(|_| ()), + ) + }); + + let mut response = srv.block_on(srv.sget().send()).unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = srv.block_on(load_body(response.take_payload())).unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); +} + +#[test] +fn test_h1_head_empty() { + let mut srv = TestServer::new(|| { + HttpService::build().h1(|_| ok::<_, ()>(Response::Ok().body(STR))) + }); + + let mut response = srv.block_on(srv.head().send()).unwrap(); + assert!(response.status().is_success()); + + { + let len = response + .headers() + .get(http::header::CONTENT_LENGTH) + .unwrap(); + assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); + } + + // read response + let bytes = srv.block_on(load_body(response.take_payload())).unwrap(); + assert!(bytes.is_empty()); +} + +#[test] +fn test_h2_head_empty() { + let openssl = ssl_acceptor().unwrap(); + let mut srv = TestServer::new(move || { + openssl + .clone() + .map_err(|e| println!("Openssl error: {}", e)) + .and_then( + HttpService::build() + .finish(|_| ok::<_, ()>(Response::Ok().body(STR))) + .map_err(|_| ()), + ) + }); + + let mut response = srv.block_on(srv.shead().send()).unwrap(); + assert!(response.status().is_success()); + assert_eq!(response.version(), http::Version::HTTP_2); + + { + let len = response + .headers() + .get(http::header::CONTENT_LENGTH) + .unwrap(); + assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); + } + + // read response + let bytes = srv.block_on(load_body(response.take_payload())).unwrap(); + assert!(bytes.is_empty()); +} + +#[test] +fn test_h1_head_binary() { + let mut srv = TestServer::new(|| { + HttpService::build().h1(|_| { + ok::<_, ()>(Response::Ok().content_length(STR.len() as u64).body(STR)) + }) + }); + + let mut response = srv.block_on(srv.head().send()).unwrap(); + assert!(response.status().is_success()); + + { + let len = response + .headers() + .get(http::header::CONTENT_LENGTH) + .unwrap(); + assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); + } + + // read response + let bytes = srv.block_on(load_body(response.take_payload())).unwrap(); + assert!(bytes.is_empty()); +} + +#[test] +fn test_h2_head_binary() { + let openssl = ssl_acceptor().unwrap(); + let mut srv = TestServer::new(move || { + openssl + .clone() + .map_err(|e| println!("Openssl error: {}", e)) + .and_then( + HttpService::build() + .h2(|_| { + ok::<_, ()>( + Response::Ok().content_length(STR.len() as u64).body(STR), + ) + }) + .map_err(|_| ()), + ) + }); + + let mut response = srv.block_on(srv.shead().send()).unwrap(); + assert!(response.status().is_success()); + + { + let len = response + .headers() + .get(http::header::CONTENT_LENGTH) + .unwrap(); + assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); + } + + // read response + let bytes = srv.block_on(load_body(response.take_payload())).unwrap(); + assert!(bytes.is_empty()); +} + +#[test] +fn test_h1_head_binary2() { + let mut srv = TestServer::new(|| { + HttpService::build().h1(|_| ok::<_, ()>(Response::Ok().body(STR))) + }); + + let response = srv.block_on(srv.head().send()).unwrap(); + assert!(response.status().is_success()); + + { + let len = response + .headers() + .get(http::header::CONTENT_LENGTH) + .unwrap(); + assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); + } +} + +#[test] +fn test_h2_head_binary2() { + let openssl = ssl_acceptor().unwrap(); + let mut srv = TestServer::new(move || { + openssl + .clone() + .map_err(|e| println!("Openssl error: {}", e)) + .and_then( + HttpService::build() + .h2(|_| ok::<_, ()>(Response::Ok().body(STR))) + .map_err(|_| ()), + ) + }); + + let response = srv.block_on(srv.shead().send()).unwrap(); + assert!(response.status().is_success()); + + { + let len = response + .headers() + .get(http::header::CONTENT_LENGTH) + .unwrap(); + assert_eq!(format!("{}", STR.len()), len.to_str().unwrap()); + } +} + +#[test] +fn test_h1_body_length() { + let mut srv = TestServer::new(|| { + HttpService::build().h1(|_| { + let body = once(Ok(Bytes::from_static(STR.as_ref()))); + ok::<_, ()>( + Response::Ok() + .body(Body::from_message(body::SizedStream::new(STR.len(), body))), + ) + }) + }); + + let mut response = srv.block_on(srv.get().send()).unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = srv.block_on(load_body(response.take_payload())).unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); +} + +#[test] +fn test_h2_body_length() { + let openssl = ssl_acceptor().unwrap(); + let mut srv = TestServer::new(move || { + openssl + .clone() + .map_err(|e| println!("Openssl error: {}", e)) + .and_then( + HttpService::build() + .h2(|_| { + let body = once(Ok(Bytes::from_static(STR.as_ref()))); + ok::<_, ()>(Response::Ok().body(Body::from_message( + body::SizedStream::new(STR.len(), body), + ))) + }) + .map_err(|_| ()), + ) + }); + + let mut response = srv.block_on(srv.sget().send()).unwrap(); + assert!(response.status().is_success()); + + // read response + let bytes = srv.block_on(load_body(response.take_payload())).unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); +} + +#[test] +fn test_h1_body_chunked_explicit() { + let mut srv = TestServer::new(|| { + HttpService::build().h1(|_| { + let body = once::<_, Error>(Ok(Bytes::from_static(STR.as_ref()))); + ok::<_, ()>( + Response::Ok() + .header(header::TRANSFER_ENCODING, "chunked") + .streaming(body), + ) + }) + }); + + let mut response = srv.block_on(srv.get().send()).unwrap(); + assert!(response.status().is_success()); + assert_eq!( + response + .headers() + .get(header::TRANSFER_ENCODING) + .unwrap() + .to_str() + .unwrap(), + "chunked" + ); + + // read response + let bytes = srv.block_on(load_body(response.take_payload())).unwrap(); + + // decode + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); +} + +#[test] +fn test_h2_body_chunked_explicit() { + let openssl = ssl_acceptor().unwrap(); + let mut srv = TestServer::new(move || { + openssl + .clone() + .map_err(|e| println!("Openssl error: {}", e)) + .and_then( + HttpService::build() + .h2(|_| { + let body = + once::<_, Error>(Ok(Bytes::from_static(STR.as_ref()))); + ok::<_, ()>( + Response::Ok() + .header(header::TRANSFER_ENCODING, "chunked") + .streaming(body), + ) + }) + .map_err(|_| ()), + ) + }); + + let mut response = srv.block_on(srv.sget().send()).unwrap(); + assert!(response.status().is_success()); + assert!(!response.headers().contains_key(header::TRANSFER_ENCODING)); + + // read response + let bytes = srv.block_on(load_body(response.take_payload())).unwrap(); + + // decode + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); +} + +#[test] +fn test_h1_body_chunked_implicit() { + let mut srv = TestServer::new(|| { + HttpService::build().h1(|_| { + let body = once::<_, Error>(Ok(Bytes::from_static(STR.as_ref()))); + ok::<_, ()>(Response::Ok().streaming(body)) + }) + }); + + let mut response = srv.block_on(srv.get().send()).unwrap(); + assert!(response.status().is_success()); + assert_eq!( + response + .headers() + .get(header::TRANSFER_ENCODING) + .unwrap() + .to_str() + .unwrap(), + "chunked" + ); + + // read response + let bytes = srv.block_on(load_body(response.take_payload())).unwrap(); + assert_eq!(bytes, Bytes::from_static(STR.as_ref())); +} + +#[test] +fn test_h1_response_http_error_handling() { + let mut srv = TestServer::new(|| { + HttpService::build().h1(fn_cfg_factory(|_: &ServerConfig| { + Ok::<_, ()>(|_| { + let broken_header = Bytes::from_static(b"\0\0\0"); + ok::<_, ()>( + Response::Ok() + .header(http::header::CONTENT_TYPE, broken_header) + .body(STR), + ) + }) + })) + }); + + let mut response = srv.block_on(srv.get().send()).unwrap(); + assert_eq!(response.status(), http::StatusCode::INTERNAL_SERVER_ERROR); + + // read response + let bytes = srv.block_on(load_body(response.take_payload())).unwrap(); + assert!(bytes.is_empty()); +} + +#[test] +fn test_h2_response_http_error_handling() { + let openssl = ssl_acceptor().unwrap(); + + let mut srv = TestServer::new(move || { + openssl + .clone() + .map_err(|e| println!("Openssl error: {}", e)) + .and_then( + HttpService::build() + .h2(fn_cfg_factory(|_: &ServerConfig| { + Ok::<_, ()>(|_| { + let broken_header = Bytes::from_static(b"\0\0\0"); + ok::<_, ()>( + Response::Ok() + .header(http::header::CONTENT_TYPE, broken_header) + .body(STR), + ) + }) + })) + .map_err(|_| ()), + ) + }); + + let mut response = srv.block_on(srv.sget().send()).unwrap(); + assert_eq!(response.status(), http::StatusCode::INTERNAL_SERVER_ERROR); + + // read response + let bytes = srv.block_on(load_body(response.take_payload())).unwrap(); + assert!(bytes.is_empty()); +} + +#[test] +fn test_h1_service_error() { + let mut srv = TestServer::new(|| { + HttpService::build() + .h1(|_| Err::(error::ErrorBadRequest("error"))) + }); + + let mut response = srv.block_on(srv.get().send()).unwrap(); + assert_eq!(response.status(), http::StatusCode::INTERNAL_SERVER_ERROR); + + // read response + let bytes = srv.block_on(load_body(response.take_payload())).unwrap(); + assert!(bytes.is_empty()); +} + +#[test] +fn test_h2_service_error() { + let openssl = ssl_acceptor().unwrap(); + + let mut srv = TestServer::new(move || { + openssl + .clone() + .map_err(|e| println!("Openssl error: {}", e)) + .and_then( + HttpService::build() + .h2(|_| Err::(error::ErrorBadRequest("error"))) + .map_err(|_| ()), + ) + }); + + let mut response = srv.block_on(srv.sget().send()).unwrap(); + assert_eq!(response.status(), http::StatusCode::INTERNAL_SERVER_ERROR); + + // read response + let bytes = srv.block_on(load_body(response.take_payload())).unwrap(); + assert!(bytes.is_empty()); +} diff --git a/actix-http/tests/test_ws.rs b/actix-http/tests/test_ws.rs new file mode 100644 index 000000000..d8942fb17 --- /dev/null +++ b/actix-http/tests/test_ws.rs @@ -0,0 +1,110 @@ +use std::io; + +use actix_codec::Framed; +use actix_http_test::TestServer; +use actix_server::Io; +use actix_service::{fn_service, NewService}; +use actix_utils::framed::IntoFramed; +use actix_utils::stream::TakeItem; +use bytes::{Bytes, BytesMut}; +use futures::future::{ok, Either}; +use futures::{Future, Sink, Stream}; +use tokio_tcp::TcpStream; + +use actix_http::{h1, ws, ResponseError, SendResponse, ServiceConfig}; + +fn ws_service(req: ws::Frame) -> impl Future { + match req { + ws::Frame::Ping(msg) => ok(ws::Message::Pong(msg)), + ws::Frame::Text(text) => { + let text = if let Some(pl) = text { + String::from_utf8(Vec::from(pl.as_ref())).unwrap() + } else { + String::new() + }; + ok(ws::Message::Text(text)) + } + ws::Frame::Binary(bin) => ok(ws::Message::Binary( + bin.map(|e| e.freeze()) + .unwrap_or_else(|| Bytes::from("")) + .into(), + )), + ws::Frame::Close(reason) => ok(ws::Message::Close(reason)), + _ => ok(ws::Message::Close(None)), + } +} + +#[test] +fn test_simple() { + let mut srv = TestServer::new(|| { + fn_service(|io: Io| Ok(io.into_parts().0)) + .and_then(IntoFramed::new(|| h1::Codec::new(ServiceConfig::default()))) + .and_then(TakeItem::new().map_err(|_| ())) + .and_then(|(req, framed): (_, Framed<_, _>)| { + // validate request + if let Some(h1::Message::Item(req)) = req { + match ws::verify_handshake(&req) { + Err(e) => { + // validation failed + Either::A( + SendResponse::send(framed, e.error_response()) + .map_err(|_| ()) + .map(|_| ()), + ) + } + Ok(_) => { + Either::B( + // send handshake response + SendResponse::send( + framed, + ws::handshake_response(&req).finish(), + ) + .map_err(|_| ()) + .and_then(|framed| { + // start websocket service + let framed = framed.into_framed(ws::Codec::new()); + ws::Transport::with(framed, ws_service) + .map_err(|_| ()) + }), + ) + } + } + } else { + panic!() + } + }) + }); + + // client service + let framed = srv.ws().unwrap(); + let framed = srv + .block_on(framed.send(ws::Message::Text("text".to_string()))) + .unwrap(); + let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); + assert_eq!(item, Some(ws::Frame::Text(Some(BytesMut::from("text"))))); + + let framed = srv + .block_on(framed.send(ws::Message::Binary("text".into()))) + .unwrap(); + let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); + assert_eq!( + item, + Some(ws::Frame::Binary(Some(Bytes::from_static(b"text").into()))) + ); + + let framed = srv + .block_on(framed.send(ws::Message::Ping("text".into()))) + .unwrap(); + let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); + assert_eq!(item, Some(ws::Frame::Pong("text".to_string().into()))); + + let framed = srv + .block_on(framed.send(ws::Message::Close(Some(ws::CloseCode::Normal.into())))) + .unwrap(); + + let (item, _framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap(); + assert_eq!( + item, + Some(ws::Frame::Close(Some(ws::CloseCode::Normal.into()))) + ); +}