1
0
Fork 0
mirror of https://github.com/actix/actix-web.git synced 2025-01-04 14:28:50 +00:00

better ergonomics for WsClient::client()

This commit is contained in:
Nikolay Kim 2018-02-24 08:14:21 +03:00
parent fd31eb74c5
commit a855c8b2c9
3 changed files with 98 additions and 65 deletions

View file

@ -182,7 +182,7 @@ impl TestServer {
/// Connect to websocket server /// Connect to websocket server
pub fn ws(&mut self) -> Result<(WsClientReader, WsClientWriter), WsClientError> { pub fn ws(&mut self) -> Result<(WsClientReader, WsClientWriter), WsClientError> {
let url = self.url("/"); let url = self.url("/");
self.system.run_until_complete(WsClient::new(url).connect().unwrap()) self.system.run_until_complete(WsClient::new(url).connect())
} }
/// Create `GET` request /// Create `GET` request

View file

@ -31,9 +31,6 @@ use super::Message;
use super::frame::Frame; use super::frame::Frame;
use super::proto::{CloseCode, OpCode}; use super::proto::{CloseCode, OpCode};
pub type WsClientFuture =
Future<Item=(WsClientReader, WsClientWriter), Error=WsClientError>;
/// Websocket client error /// Websocket client error
#[derive(Fail, Debug)] #[derive(Fail, Debug)]
@ -140,7 +137,7 @@ impl WsClient {
} }
/// Set cookie for handshake request /// Set cookie for handshake request
pub fn cookie<'c>(mut self, cookie: Cookie<'c>) -> Self { pub fn cookie(mut self, cookie: Cookie) -> Self {
self.request.cookie(cookie); self.request.cookie(cookie);
self self
} }
@ -165,49 +162,46 @@ impl WsClient {
} }
/// Connect to websocket server and do ws handshake /// Connect to websocket server and do ws handshake
pub fn connect(&mut self) -> Result<Box<WsClientFuture>, WsClientError> { pub fn connect(&mut self) -> WsHandshake {
if let Some(e) = self.err.take() { if let Some(e) = self.err.take() {
return Err(e) WsHandshake::new(None, Some(e), &self.conn)
} }
if let Some(e) = self.http_err.take() { else if let Some(e) = self.http_err.take() {
return Err(e.into()) WsHandshake::new(None, Some(e.into()), &self.conn)
}
// origin
if let Some(origin) = self.origin.take() {
self.request.set_header(header::ORIGIN, origin);
}
self.request.upgrade();
self.request.set_header(header::UPGRADE, "websocket");
self.request.set_header(header::CONNECTION, "upgrade");
self.request.set_header("SEC-WEBSOCKET-VERSION", "13");
if let Some(protocols) = self.protocols.take() {
self.request.set_header("SEC-WEBSOCKET-PROTOCOL", protocols.as_str());
}
let request = self.request.finish()?;
if request.uri().host().is_none() {
return Err(WsClientError::InvalidUrl)
}
if let Some(scheme) = request.uri().scheme_part() {
if scheme != "http" && scheme != "https" && scheme != "ws" && scheme != "wss" {
return Err(WsClientError::InvalidUrl);
}
} else { } else {
return Err(WsClientError::InvalidUrl); // origin
} if let Some(origin) = self.origin.take() {
self.request.set_header(header::ORIGIN, origin);
}
// get connection and start handshake self.request.upgrade();
Ok(Box::new( self.request.set_header(header::UPGRADE, "websocket");
self.conn.send(Connect(request.uri().clone())) self.request.set_header(header::CONNECTION, "upgrade");
.map_err(|_| WsClientError::Disconnected) self.request.set_header("SEC-WEBSOCKET-VERSION", "13");
.and_then(|res| match res {
Ok(stream) => Either::A(WsHandshake::new(stream, request)), if let Some(protocols) = self.protocols.take() {
Err(err) => Either::B(FutErr(err.into())), self.request.set_header("SEC-WEBSOCKET-PROTOCOL", protocols.as_str());
}) }
)) let request = match self.request.finish() {
Ok(req) => req,
Err(err) => return WsHandshake::new(None, Some(err.into()), &self.conn),
};
if request.uri().host().is_none() {
return WsHandshake::new(None, Some(WsClientError::InvalidUrl), &self.conn)
}
if let Some(scheme) = request.uri().scheme_part() {
if scheme != "http" && scheme != "https" && scheme != "ws" && scheme != "wss" {
return WsHandshake::new(
None, Some(WsClientError::InvalidUrl), &self.conn)
}
} else {
return WsHandshake::new(None, Some(WsClientError::InvalidUrl), &self.conn)
}
// start handshake
WsHandshake::new(Some(request), None, &self.conn)
}
} }
} }
@ -220,39 +214,53 @@ struct WsInner {
error_sent: bool, error_sent: bool,
} }
struct WsHandshake { pub struct WsHandshake {
inner: Option<WsInner>, inner: Option<WsInner>,
request: ClientRequest, request: Option<ClientRequest>,
sent: bool, sent: bool,
key: String, key: String,
error: Option<WsClientError>,
stream: Option<Box<Future<Item=Result<Connection, WsClientError>, Error=WsClientError>>>,
} }
impl WsHandshake { impl WsHandshake {
fn new(conn: Connection, mut request: ClientRequest) -> WsHandshake { fn new(request: Option<ClientRequest>,
err: Option<WsClientError>,
conn: &Addr<Unsync, ClientConnector>) -> WsHandshake
{
// Generate a random key for the `Sec-WebSocket-Key` header. // Generate a random key for the `Sec-WebSocket-Key` header.
// a base64-encoded (see Section 4 of [RFC4648]) value that, // a base64-encoded (see Section 4 of [RFC4648]) value that,
// when decoded, is 16 bytes in length (RFC 6455) // when decoded, is 16 bytes in length (RFC 6455)
let sec_key: [u8; 16] = rand::random(); let sec_key: [u8; 16] = rand::random();
let key = base64::encode(&sec_key); let key = base64::encode(&sec_key);
request.headers_mut().insert( if let Some(mut request) = request {
HeaderName::try_from("SEC-WEBSOCKET-KEY").unwrap(), let stream = Box::new(
HeaderValue::try_from(key.as_str()).unwrap()); conn.send(Connect(request.uri().clone()))
.map(|res| res.map_err(|e| e.into()))
.map_err(|_| WsClientError::Disconnected));
let inner = WsInner { request.headers_mut().insert(
conn: conn, HeaderName::try_from("SEC-WEBSOCKET-KEY").unwrap(),
writer: HttpClientWriter::new(SharedBytes::default()), HeaderValue::try_from(key.as_str()).unwrap());
parser: HttpResponseParser::default(),
parser_buf: BytesMut::new(),
closed: false,
error_sent: false,
};
WsHandshake { WsHandshake {
key: key, key: key,
inner: Some(inner), inner: None,
request: request, request: Some(request),
sent: false, sent: false,
error: err,
stream: Some(stream),
}
} else {
WsHandshake {
key: key,
inner: None,
request: None,
sent: false,
error: err,
stream: None,
}
} }
} }
} }
@ -262,11 +270,36 @@ impl Future for WsHandshake {
type Error = WsClientError; type Error = WsClientError;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
if let Some(err) = self.error.take() {
return Err(err)
}
if self.stream.is_some() {
match self.stream.as_mut().unwrap().poll()? {
Async::Ready(result) => match result {
Ok(conn) => {
let inner = WsInner {
conn: conn,
writer: HttpClientWriter::new(SharedBytes::default()),
parser: HttpResponseParser::default(),
parser_buf: BytesMut::new(),
closed: false,
error_sent: false,
};
self.stream.take();
self.inner = Some(inner);
}
Err(err) => return Err(err),
},
Async::NotReady => return Ok(Async::NotReady)
}
}
let mut inner = self.inner.take().unwrap(); let mut inner = self.inner.take().unwrap();
if !self.sent { if !self.sent {
self.sent = true; self.sent = true;
inner.writer.start(&mut self.request)?; inner.writer.start(self.request.as_mut().unwrap())?;
} }
if let Err(err) = inner.writer.poll_completed(&mut inner.conn, false) { if let Err(err) = inner.writer.poll_completed(&mut inner.conn, false) {
return Err(err.into()) return Err(err.into())

View file

@ -65,7 +65,7 @@ use self::frame::Frame;
use self::proto::{hash_key, OpCode}; use self::proto::{hash_key, OpCode};
pub use self::proto::CloseCode; pub use self::proto::CloseCode;
pub use self::context::WebsocketContext; pub use self::context::WebsocketContext;
pub use self::client::{WsClient, WsClientError, WsClientReader, WsClientWriter, WsClientFuture}; pub use self::client::{WsClient, WsClientError, WsClientReader, WsClientWriter, WsHandshake};
const SEC_WEBSOCKET_ACCEPT: &str = "SEC-WEBSOCKET-ACCEPT"; const SEC_WEBSOCKET_ACCEPT: &str = "SEC-WEBSOCKET-ACCEPT";
const SEC_WEBSOCKET_KEY: &str = "SEC-WEBSOCKET-KEY"; const SEC_WEBSOCKET_KEY: &str = "SEC-WEBSOCKET-KEY";