diff --git a/src/client/mod.rs b/src/client/mod.rs index 7fce930fc..dad0c3f38 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -1,5 +1,7 @@ mod parser; +mod request; mod response; +pub use self::request::{ClientRequest, ClientRequestBuilder}; pub use self::response::ClientResponse; pub use self::parser::{HttpResponseParser, HttpResponseParserError}; diff --git a/src/client/parser.rs b/src/client/parser.rs index d072f69b4..73ef2278c 100644 --- a/src/client/parser.rs +++ b/src/client/parser.rs @@ -17,7 +17,7 @@ use super::ClientResponse; const MAX_BUFFER_SIZE: usize = 131_072; const MAX_HEADERS: usize = 96; - +#[derive(Default)] pub struct HttpResponseParser { payload: Option, } @@ -41,11 +41,6 @@ pub enum HttpResponseParserError { } impl HttpResponseParser { - pub fn new() -> HttpResponseParser { - HttpResponseParser { - payload: None, - } - } fn decode(&mut self, buf: &mut BytesMut) -> Result { if let Some(ref mut payload) = self.payload { diff --git a/src/client/request.rs b/src/client/request.rs new file mode 100644 index 000000000..c441947f8 --- /dev/null +++ b/src/client/request.rs @@ -0,0 +1,384 @@ +use std::{fmt, mem}; +use std::io::Write; + +use cookie::{Cookie, CookieJar}; +use bytes::{BytesMut, BufMut}; +use http::{HeaderMap, Method, Version, Uri, HttpTryFrom, Error as HttpError}; +use http::header::{self, HeaderName, HeaderValue}; +use serde_json; +use serde::Serialize; + +use body::Body; +use error::Error; +use headers::ContentEncoding; + + +pub struct ClientRequest { + uri: Uri, + method: Method, + version: Version, + headers: HeaderMap, + body: Body, + chunked: Option, + encoding: ContentEncoding, +} + +impl Default for ClientRequest { + + fn default() -> ClientRequest { + ClientRequest { + uri: Uri::default(), + method: Method::default(), + version: Version::HTTP_11, + headers: HeaderMap::with_capacity(16), + body: Body::Empty, + chunked: None, + encoding: ContentEncoding::Auto, + } + } +} + +impl ClientRequest { + + /// Create client request builder + pub fn build() -> ClientRequestBuilder { + ClientRequestBuilder { + request: Some(ClientRequest::default()), + err: None, + cookies: None, + } + } + + /// Get the request uri + #[inline] + pub fn uri(&self) -> &Uri { + &self.uri + } + + /// Set client request uri + #[inline] + pub fn set_uri(&mut self, uri: Uri) { + self.uri = uri + } + + /// Get the request method + #[inline] + pub fn method(&self) -> &Method { + &self.method + } + + /// Set http `Method` for the request + #[inline] + pub fn set_method(&mut self, method: Method) { + self.method = method + } + + /// Get the headers from the request + #[inline] + pub fn headers(&self) -> &HeaderMap { + &self.headers + } + + /// Get a mutable reference to the headers + #[inline] + pub fn headers_mut(&mut self) -> &mut HeaderMap { + &mut self.headers + } + + /// Get body os this response + #[inline] + pub fn body(&self) -> &Body { + &self.body + } + + /// Set a body + pub fn set_body>(&mut self, body: B) { + self.body = body.into(); + } + + /// Set a body and return previous body value + pub fn replace_body>(&mut self, body: B) -> Body { + mem::replace(&mut self.body, body.into()) + } +} + +impl fmt::Debug for ClientRequest { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let res = write!(f, "\nClientRequest {:?} {}:{}\n", + self.version, self.method, self.uri); + let _ = write!(f, " headers:\n"); + for key in self.headers.keys() { + let vals: Vec<_> = self.headers.get_all(key).iter().collect(); + if vals.len() > 1 { + let _ = write!(f, " {:?}: {:?}\n", key, vals); + } else { + let _ = write!(f, " {:?}: {:?}\n", key, vals[0]); + } + } + res + } +} + + +pub struct ClientRequestBuilder { + request: Option, + err: Option, + cookies: Option, +} + +impl ClientRequestBuilder { + /// Set HTTP uri of request. + #[inline] + pub fn uri(&mut self, uri: U) -> &mut Self where Uri: HttpTryFrom { + match Uri::try_from(uri) { + Ok(uri) => { + // set request host header + if let Some(host) = uri.host() { + self.set_header(header::HOST, host); + } + if let Some(parts) = parts(&mut self.request, &self.err) { + parts.uri = uri; + } + }, + Err(e) => self.err = Some(e.into(),), + } + self + } + + /// Set HTTP method of this request. + #[inline] + pub fn method(&mut self, method: Method) -> &mut Self { + if let Some(parts) = parts(&mut self.request, &self.err) { + parts.method = method; + } + self + } + + /// Set HTTP version of this request. + /// + /// By default requests's http version depends on network stream + #[inline] + pub fn version(&mut self, version: Version) -> &mut Self { + if let Some(parts) = parts(&mut self.request, &self.err) { + parts.version = version; + } + self + } + + /// Set a header. + /// + /// ```rust + /// # extern crate http; + /// # extern crate actix_web; + /// # use actix_web::client::*; + /// # + /// use http::header; + /// + /// fn main() { + /// let req = ClientRequest::build() + /// .header("X-TEST", "value") + /// .header(header::CONTENT_TYPE, "application/json") + /// .finish().unwrap(); + /// } + /// ``` + pub fn header(&mut self, key: K, value: V) -> &mut Self + where HeaderName: HttpTryFrom, HeaderValue: HttpTryFrom + { + if let Some(parts) = parts(&mut self.request, &self.err) { + match HeaderName::try_from(key) { + Ok(key) => { + match HeaderValue::try_from(value) { + Ok(value) => { parts.headers.append(key, value); } + Err(e) => self.err = Some(e.into()), + } + }, + Err(e) => self.err = Some(e.into()), + }; + } + self + } + + /// Replace a header. + pub fn set_header(&mut self, key: K, value: V) -> &mut Self + where HeaderName: HttpTryFrom, HeaderValue: HttpTryFrom + { + if let Some(parts) = parts(&mut self.request, &self.err) { + match HeaderName::try_from(key) { + Ok(key) => { + match HeaderValue::try_from(value) { + Ok(value) => { parts.headers.insert(key, value); } + Err(e) => self.err = Some(e.into()), + } + }, + Err(e) => self.err = Some(e.into()), + }; + } + self + } + + /// Set content encoding. + /// + /// By default `ContentEncoding::Identity` is used. + #[inline] + pub fn content_encoding(&mut self, enc: ContentEncoding) -> &mut Self { + if let Some(parts) = parts(&mut self.request, &self.err) { + parts.encoding = enc; + } + self + } + + /// Enables automatic chunked transfer encoding + #[inline] + pub fn chunked(&mut self) -> &mut Self { + if let Some(parts) = parts(&mut self.request, &self.err) { + parts.chunked = Some(true); + } + self + } + + /// Set request's content type + #[inline] + pub fn content_type(&mut self, value: V) -> &mut Self + where HeaderValue: HttpTryFrom + { + if let Some(parts) = parts(&mut self.request, &self.err) { + match HeaderValue::try_from(value) { + Ok(value) => { parts.headers.insert(header::CONTENT_TYPE, value); }, + Err(e) => self.err = Some(e.into()), + }; + } + self + } + + /// Set content length + #[inline] + pub fn content_length(&mut self, len: u64) -> &mut Self { + let mut wrt = BytesMut::new().writer(); + let _ = write!(wrt, "{}", len); + self.header(header::CONTENT_LENGTH, wrt.get_mut().take().freeze()) + } + + /// Set a cookie + /// + /// ```rust + /// # extern crate actix_web; + /// # use actix_web::*; + /// # use actix_web::httpcodes::*; + /// # + /// use actix_web::headers::Cookie; + /// + /// fn index(req: HttpRequest) -> Result { + /// Ok(HTTPOk.build() + /// .cookie( + /// Cookie::build("name", "value") + /// .domain("www.rust-lang.org") + /// .path("/") + /// .secure(true) + /// .http_only(true) + /// .finish()) + /// .finish()?) + /// } + /// fn main() {} + /// ``` + pub fn cookie<'c>(&mut self, cookie: Cookie<'c>) -> &mut Self { + if self.cookies.is_none() { + let mut jar = CookieJar::new(); + jar.add(cookie.into_owned()); + self.cookies = Some(jar) + } else { + self.cookies.as_mut().unwrap().add(cookie.into_owned()); + } + self + } + + /// Remove cookie, cookie has to be cookie from `HttpRequest::cookies()` method. + pub fn del_cookie<'a>(&mut self, cookie: &Cookie<'a>) -> &mut Self { + { + if self.cookies.is_none() { + self.cookies = Some(CookieJar::new()) + } + let jar = self.cookies.as_mut().unwrap(); + let cookie = cookie.clone().into_owned(); + jar.add_original(cookie.clone()); + jar.remove(cookie); + } + self + } + + /// This method calls provided closure with builder reference if value is true. + pub fn if_true(&mut self, value: bool, f: F) -> &mut Self + where F: FnOnce(&mut ClientRequestBuilder) + { + if value { + f(self); + } + self + } + + /// This method calls provided closure with builder reference if value is Some. + pub fn if_some(&mut self, value: Option, f: F) -> &mut Self + where F: FnOnce(T, &mut ClientRequestBuilder) + { + if let Some(val) = value { + f(val, self); + } + self + } + + /// Set a body and generate `ClientRequest`. + /// + /// `ClientRequestBuilder` can not be used after this call. + pub fn body>(&mut self, body: B) -> Result { + if let Some(e) = self.err.take() { + return Err(e) + } + + let mut request = self.request.take().expect("cannot reuse request builder"); + + // set cookies + if let Some(ref jar) = self.cookies { + for cookie in jar.delta() { + request.headers.append( + header::SET_COOKIE, + HeaderValue::from_str(&cookie.to_string())?); + } + } + request.body = body.into(); + Ok(request) + } + + /// Set a json body and generate `ClientRequest` + /// + /// `ClientRequestBuilder` can not be used after this call. + pub fn json(&mut self, value: T) -> Result { + let body = serde_json::to_string(&value)?; + + let contains = if let Some(parts) = parts(&mut self.request, &self.err) { + parts.headers.contains_key(header::CONTENT_TYPE) + } else { + true + }; + if !contains { + self.header(header::CONTENT_TYPE, "application/json"); + } + + Ok(self.body(body)?) + } + + /// Set an empty body and generate `ClientRequest` + /// + /// `ClientRequestBuilder` can not be used after this call. + pub fn finish(&mut self) -> Result { + self.body(Body::Empty) + } +} + +#[inline] +fn parts<'a>(parts: &'a mut Option, err: &Option) + -> Option<&'a mut ClientRequest> +{ + if err.is_some() { + return None + } + parts.as_mut() +} diff --git a/src/httpresponse.rs b/src/httpresponse.rs index 20607b95b..ebcaefe9d 100644 --- a/src/httpresponse.rs +++ b/src/httpresponse.rs @@ -83,37 +83,37 @@ impl HttpResponse { self.get_ref().error.as_ref() } - /// Get the HTTP version of this response. + /// Get the HTTP version of this response #[inline] pub fn version(&self) -> Option { self.get_ref().version } - /// Get the headers from the response. + /// Get the headers from the response #[inline] pub fn headers(&self) -> &HeaderMap { &self.get_ref().headers } - /// Get a mutable reference to the headers. + /// Get a mutable reference to the headers #[inline] pub fn headers_mut(&mut self) -> &mut HeaderMap { &mut self.get_mut().headers } - /// Get the status from the server. + /// Get the response status code #[inline] pub fn status(&self) -> StatusCode { self.get_ref().status } - /// Set the `StatusCode` for this response. + /// Set the `StatusCode` for this response #[inline] pub fn status_mut(&mut self) -> &mut StatusCode { &mut self.get_mut().status } - /// Get custom reason for the response. + /// Get custom reason for the response #[inline] pub fn reason(&self) -> &str { if let Some(reason) = self.get_ref().reason { @@ -123,7 +123,7 @@ impl HttpResponse { } } - /// Set the custom reason for the response. + /// Set the custom reason for the response #[inline] pub fn set_reason(&mut self, reason: &'static str) -> &mut Self { self.get_mut().reason = Some(reason); diff --git a/src/lib.rs b/src/lib.rs index 5bf8020cf..02e2b1b14 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -109,7 +109,8 @@ mod resource; mod handler; mod pipeline; -mod client; +#[doc(hidden)] +pub mod client; pub mod fs; pub mod ws; diff --git a/src/ws/client.rs b/src/ws/client.rs index c4fa762b3..2e6094e6c 100644 --- a/src/ws/client.rs +++ b/src/ws/client.rs @@ -1,28 +1,27 @@ //! Http client request -use std::{fmt, io, str}; +use std::{io, str}; use std::rc::Rc; use std::time::Duration; use std::cell::UnsafeCell; use base64; use rand; -use cookie::{Cookie, CookieJar}; +use cookie::Cookie; use bytes::BytesMut; -use http::{Method, Version, HeaderMap, HttpTryFrom, StatusCode, Error as HttpError}; +use http::{HttpTryFrom, StatusCode, Error as HttpError}; use http::header::{self, HeaderName, HeaderValue}; -use url::Url; use sha1::Sha1; use futures::{Async, Future, Poll, Stream}; // use futures::unsync::oneshot; use tokio_core::net::TcpStream; -use body::{Body, Binary}; +use body::Binary; use error::UrlParseError; -use headers::ContentEncoding; use server::shared::SharedBytes; use server::{utils, IoStream}; -use client::{HttpResponseParser, HttpResponseParserError}; +use client::{ClientRequest, ClientRequestBuilder, + HttpResponseParser, HttpResponseParserError}; use super::Message; use super::proto::{CloseCode, OpCode}; @@ -91,36 +90,23 @@ type WsFuture = Future, WsWriter), Error=WsClientError>; /// Websockt client pub struct WsClient { - request: Option, + request: ClientRequestBuilder, err: Option, http_err: Option, - cookies: Option, origin: Option, protocols: Option, } impl WsClient { - pub fn new>(url: S) -> WsClient { + pub fn new>(uri: S) -> WsClient { let mut cl = WsClient { - request: None, + request: ClientRequest::build(), err: None, http_err: None, - cookies: None, origin: None, protocols: None }; - - match Url::parse(url.as_ref()) { - Ok(url) => { - if url.scheme() != "http" && url.scheme() != "https" && - url.scheme() != "ws" && url.scheme() != "wss" || !url.has_host() { - cl.err = Some(WsClientError::InvalidUrl); - } else { - cl.request = Some(ClientRequest::new(Method::GET, url)); - } - }, - Err(err) => cl.err = Some(err.into()), - } + cl.request.uri(uri.as_ref()); cl } @@ -136,13 +122,7 @@ impl WsClient { } pub fn cookie<'c>(&mut self, cookie: Cookie<'c>) -> &mut Self { - if self.cookies.is_none() { - let mut jar = CookieJar::new(); - jar.add(cookie.into_owned()); - self.cookies = Some(jar) - } else { - self.cookies.as_mut().unwrap().add(cookie.into_owned()); - } + self.request.cookie(cookie); self } @@ -158,20 +138,9 @@ impl WsClient { } pub fn header(&mut self, key: K, value: V) -> &mut Self - where HeaderName: HttpTryFrom, - HeaderValue: HttpTryFrom + where HeaderName: HttpTryFrom, HeaderValue: HttpTryFrom { - if let Some(parts) = parts(&mut self.request, &self.err, &self.http_err) { - match HeaderName::try_from(key) { - Ok(key) => { - match HeaderValue::try_from(value) { - Ok(value) => { parts.headers.append(key, value); } - Err(e) => self.http_err = Some(e.into()), - } - }, - Err(e) => self.http_err = Some(e.into()), - }; - } + self.request.header(key, value); self } @@ -182,37 +151,35 @@ impl WsClient { if let Some(e) = self.http_err.take() { return Err(e.into()) } - let mut request = self.request.take().expect("cannot reuse request builder"); - - // headers - if let Some(ref jar) = self.cookies { - for cookie in jar.delta() { - request.headers.append( - header::SET_COOKIE, - HeaderValue::from_str(&cookie.to_string()).map_err(HttpError::from)?); - } - } // origin if let Some(origin) = self.origin.take() { - request.headers.insert(header::ORIGIN, origin); + self.request.set_header(header::ORIGIN, origin); } - request.headers.insert(header::UPGRADE, HeaderValue::from_static("websocket")); - request.headers.insert(header::CONNECTION, HeaderValue::from_static("upgrade")); - request.headers.insert( - HeaderName::try_from("SEC-WEBSOCKET-VERSION").unwrap(), - HeaderValue::from_static("13")); + self.request.set_header(header::UPGRADE, "websocket"); + self.request.set_header(header::CONNECTION, "upgrade"); + self.request.set_header("SEC-WEBSOCKET-VERSION", "13"); if let Some(protocols) = self.protocols.take() { - request.headers.insert( - HeaderName::try_from("SEC-WEBSOCKET-PROTOCOL").unwrap(), - HeaderValue::try_from(protocols.as_str()).unwrap()); + self.request.set_header("SEC-WEBSOCKET-PROTOCOL", protocols.as_str()); + } + let request = self.request.finish()?; + + if request.uri().host().is_none() { + return Err(WsClientError::InvalidUrl) + } + if let Some(scheme) = request.uri().scheme_part() { + if scheme != "http" && scheme != "https" && scheme != "ws" && scheme != "wss" { + return Err(WsClientError::InvalidUrl); + } + } else { + return Err(WsClientError::InvalidUrl); } let connect = TcpConnector::new( - request.url.host_str().unwrap(), - request.url.port().unwrap_or(80), Duration::from_secs(5)); + request.uri().host().unwrap(), + request.uri().port().unwrap_or(80), Duration::from_secs(5)); Ok(Box::new( connect @@ -221,60 +188,6 @@ impl WsClient { } } -#[inline] -fn parts<'a>(parts: &'a mut Option, - err: &Option, - http_err: &Option) -> Option<&'a mut ClientRequest> -{ - if err.is_some() || http_err.is_some() { - return None - } - parts.as_mut() -} - -pub(crate) struct ClientRequest { - pub url: Url, - pub method: Method, - pub version: Version, - pub headers: HeaderMap, - pub body: Body, - pub chunked: Option, - pub encoding: ContentEncoding, -} - -impl ClientRequest { - - #[inline] - fn new(method: Method, url: Url) -> ClientRequest { - ClientRequest { - url: url, - method: method, - version: Version::HTTP_11, - headers: HeaderMap::with_capacity(16), - body: Body::Empty, - chunked: None, - encoding: ContentEncoding::Auto, - } - } -} - -impl fmt::Debug for ClientRequest { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let res = write!(f, "\nClientRequest {:?} {}:{}\n", - self.version, self.method, self.url); - let _ = write!(f, " headers:\n"); - for key in self.headers.keys() { - let vals: Vec<_> = self.headers.get_all(key).iter().collect(); - if vals.len() > 1 { - let _ = write!(f, " {:?}: {:?}\n", key, vals); - } else { - let _ = write!(f, " {:?}: {:?}\n", key, vals[0]); - } - } - res - } -} - struct WsInner { stream: T, writer: Writer, @@ -299,14 +212,14 @@ impl WsHandshake { let sec_key: [u8; 16] = rand::random(); let key = base64::encode(&sec_key); - request.headers.insert( + request.headers_mut().insert( HeaderName::try_from("SEC-WEBSOCKET-KEY").unwrap(), HeaderValue::try_from(key.as_str()).unwrap()); let inner = WsInner { stream: stream, writer: Writer::new(SharedBytes::default()), - parser: HttpResponseParser::new(), + parser: HttpResponseParser::default(), parser_buf: BytesMut::new(), closed: false, error_sent: false, @@ -370,7 +283,7 @@ impl Future for WsHandshake { { // ... field is constructed by concatenating /key/ ... // ... with the string "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" (RFC 6455) - const WS_GUID: &'static [u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + const WS_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; let mut sha1 = Sha1::new(); sha1.update(self.key.as_ref()); sha1.update(WS_GUID); diff --git a/src/ws/writer.rs b/src/ws/writer.rs index 57802b341..0cee01811 100644 --- a/src/ws/writer.rs +++ b/src/ws/writer.rs @@ -9,7 +9,7 @@ use body::Binary; use server::{WriterState, MAX_WRITE_BUFFER_SIZE}; use server::shared::SharedBytes; -use super::client::ClientRequest; +use client::ClientRequest; const AVERAGE_HEADER_SIZE: usize = 30; // totally scientific @@ -82,17 +82,17 @@ impl Writer { // render message { let buffer = self.buffer.get_mut(); - buffer.reserve(256 + msg.headers.len() * AVERAGE_HEADER_SIZE); + buffer.reserve(256 + msg.headers().len() * AVERAGE_HEADER_SIZE); // status line // helpers::write_status_line(version, msg.status().as_u16(), &mut buffer); // buffer.extend_from_slice(msg.reason().as_bytes()); buffer.extend_from_slice(b"GET "); - buffer.extend_from_slice(msg.url.path().as_ref()); + buffer.extend_from_slice(msg.uri().path().as_ref()); buffer.extend_from_slice(b" HTTP/1.1\r\n"); // write headers - for (key, value) in &msg.headers { + for (key, value) in msg.headers() { let v = value.as_ref(); let k = key.as_str().as_bytes(); buffer.reserve(k.len() + v.len() + 4);