1
0
Fork 0
mirror of https://github.com/actix/actix-web.git synced 2025-01-02 05:18:44 +00:00

better ergonomics for WsClient::client()

This commit is contained in:
Nikolay Kim 2018-02-24 08:14:21 +03:00
parent fd31eb74c5
commit a855c8b2c9
3 changed files with 98 additions and 65 deletions

View file

@ -182,7 +182,7 @@ impl TestServer {
/// Connect to websocket server
pub fn ws(&mut self) -> Result<(WsClientReader, WsClientWriter), WsClientError> {
let url = self.url("/");
self.system.run_until_complete(WsClient::new(url).connect().unwrap())
self.system.run_until_complete(WsClient::new(url).connect())
}
/// Create `GET` request

View file

@ -31,9 +31,6 @@ use super::Message;
use super::frame::Frame;
use super::proto::{CloseCode, OpCode};
pub type WsClientFuture =
Future<Item=(WsClientReader, WsClientWriter), Error=WsClientError>;
/// Websocket client error
#[derive(Fail, Debug)]
@ -140,7 +137,7 @@ impl WsClient {
}
/// Set cookie for handshake request
pub fn cookie<'c>(mut self, cookie: Cookie<'c>) -> Self {
pub fn cookie(mut self, cookie: Cookie) -> Self {
self.request.cookie(cookie);
self
}
@ -165,49 +162,46 @@ impl WsClient {
}
/// Connect to websocket server and do ws handshake
pub fn connect(&mut self) -> Result<Box<WsClientFuture>, WsClientError> {
pub fn connect(&mut self) -> WsHandshake {
if let Some(e) = self.err.take() {
return Err(e)
WsHandshake::new(None, Some(e), &self.conn)
}
if let Some(e) = self.http_err.take() {
return Err(e.into())
}
// origin
if let Some(origin) = self.origin.take() {
self.request.set_header(header::ORIGIN, origin);
}
self.request.upgrade();
self.request.set_header(header::UPGRADE, "websocket");
self.request.set_header(header::CONNECTION, "upgrade");
self.request.set_header("SEC-WEBSOCKET-VERSION", "13");
if let Some(protocols) = self.protocols.take() {
self.request.set_header("SEC-WEBSOCKET-PROTOCOL", protocols.as_str());
}
let request = self.request.finish()?;
if request.uri().host().is_none() {
return Err(WsClientError::InvalidUrl)
}
if let Some(scheme) = request.uri().scheme_part() {
if scheme != "http" && scheme != "https" && scheme != "ws" && scheme != "wss" {
return Err(WsClientError::InvalidUrl);
}
else if let Some(e) = self.http_err.take() {
WsHandshake::new(None, Some(e.into()), &self.conn)
} else {
return Err(WsClientError::InvalidUrl);
}
// origin
if let Some(origin) = self.origin.take() {
self.request.set_header(header::ORIGIN, origin);
}
// get connection and start handshake
Ok(Box::new(
self.conn.send(Connect(request.uri().clone()))
.map_err(|_| WsClientError::Disconnected)
.and_then(|res| match res {
Ok(stream) => Either::A(WsHandshake::new(stream, request)),
Err(err) => Either::B(FutErr(err.into())),
})
))
self.request.upgrade();
self.request.set_header(header::UPGRADE, "websocket");
self.request.set_header(header::CONNECTION, "upgrade");
self.request.set_header("SEC-WEBSOCKET-VERSION", "13");
if let Some(protocols) = self.protocols.take() {
self.request.set_header("SEC-WEBSOCKET-PROTOCOL", protocols.as_str());
}
let request = match self.request.finish() {
Ok(req) => req,
Err(err) => return WsHandshake::new(None, Some(err.into()), &self.conn),
};
if request.uri().host().is_none() {
return WsHandshake::new(None, Some(WsClientError::InvalidUrl), &self.conn)
}
if let Some(scheme) = request.uri().scheme_part() {
if scheme != "http" && scheme != "https" && scheme != "ws" && scheme != "wss" {
return WsHandshake::new(
None, Some(WsClientError::InvalidUrl), &self.conn)
}
} else {
return WsHandshake::new(None, Some(WsClientError::InvalidUrl), &self.conn)
}
// start handshake
WsHandshake::new(Some(request), None, &self.conn)
}
}
}
@ -220,39 +214,53 @@ struct WsInner {
error_sent: bool,
}
struct WsHandshake {
pub struct WsHandshake {
inner: Option<WsInner>,
request: ClientRequest,
request: Option<ClientRequest>,
sent: bool,
key: String,
error: Option<WsClientError>,
stream: Option<Box<Future<Item=Result<Connection, WsClientError>, Error=WsClientError>>>,
}
impl WsHandshake {
fn new(conn: Connection, mut request: ClientRequest) -> WsHandshake {
fn new(request: Option<ClientRequest>,
err: Option<WsClientError>,
conn: &Addr<Unsync, ClientConnector>) -> WsHandshake
{
// Generate a random key for the `Sec-WebSocket-Key` header.
// a base64-encoded (see Section 4 of [RFC4648]) value that,
// when decoded, is 16 bytes in length (RFC 6455)
let sec_key: [u8; 16] = rand::random();
let key = base64::encode(&sec_key);
request.headers_mut().insert(
HeaderName::try_from("SEC-WEBSOCKET-KEY").unwrap(),
HeaderValue::try_from(key.as_str()).unwrap());
if let Some(mut request) = request {
let stream = Box::new(
conn.send(Connect(request.uri().clone()))
.map(|res| res.map_err(|e| e.into()))
.map_err(|_| WsClientError::Disconnected));
let inner = WsInner {
conn: conn,
writer: HttpClientWriter::new(SharedBytes::default()),
parser: HttpResponseParser::default(),
parser_buf: BytesMut::new(),
closed: false,
error_sent: false,
};
request.headers_mut().insert(
HeaderName::try_from("SEC-WEBSOCKET-KEY").unwrap(),
HeaderValue::try_from(key.as_str()).unwrap());
WsHandshake {
key: key,
inner: Some(inner),
request: request,
sent: false,
WsHandshake {
key: key,
inner: None,
request: Some(request),
sent: false,
error: err,
stream: Some(stream),
}
} else {
WsHandshake {
key: key,
inner: None,
request: None,
sent: false,
error: err,
stream: None,
}
}
}
}
@ -262,11 +270,36 @@ impl Future for WsHandshake {
type Error = WsClientError;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
if let Some(err) = self.error.take() {
return Err(err)
}
if self.stream.is_some() {
match self.stream.as_mut().unwrap().poll()? {
Async::Ready(result) => match result {
Ok(conn) => {
let inner = WsInner {
conn: conn,
writer: HttpClientWriter::new(SharedBytes::default()),
parser: HttpResponseParser::default(),
parser_buf: BytesMut::new(),
closed: false,
error_sent: false,
};
self.stream.take();
self.inner = Some(inner);
}
Err(err) => return Err(err),
},
Async::NotReady => return Ok(Async::NotReady)
}
}
let mut inner = self.inner.take().unwrap();
if !self.sent {
self.sent = true;
inner.writer.start(&mut self.request)?;
inner.writer.start(self.request.as_mut().unwrap())?;
}
if let Err(err) = inner.writer.poll_completed(&mut inner.conn, false) {
return Err(err.into())

View file

@ -65,7 +65,7 @@ use self::frame::Frame;
use self::proto::{hash_key, OpCode};
pub use self::proto::CloseCode;
pub use self::context::WebsocketContext;
pub use self::client::{WsClient, WsClientError, WsClientReader, WsClientWriter, WsClientFuture};
pub use self::client::{WsClient, WsClientError, WsClientReader, WsClientWriter, WsHandshake};
const SEC_WEBSOCKET_ACCEPT: &str = "SEC-WEBSOCKET-ACCEPT";
const SEC_WEBSOCKET_KEY: &str = "SEC-WEBSOCKET-KEY";