mirror of
https://github.com/actix/actix-web.git
synced 2025-01-04 14:28:50 +00:00
better ergonomics for WsClient::client()
This commit is contained in:
parent
fd31eb74c5
commit
a855c8b2c9
3 changed files with 98 additions and 65 deletions
|
@ -182,7 +182,7 @@ impl TestServer {
|
||||||
/// Connect to websocket server
|
/// Connect to websocket server
|
||||||
pub fn ws(&mut self) -> Result<(WsClientReader, WsClientWriter), WsClientError> {
|
pub fn ws(&mut self) -> Result<(WsClientReader, WsClientWriter), WsClientError> {
|
||||||
let url = self.url("/");
|
let url = self.url("/");
|
||||||
self.system.run_until_complete(WsClient::new(url).connect().unwrap())
|
self.system.run_until_complete(WsClient::new(url).connect())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create `GET` request
|
/// Create `GET` request
|
||||||
|
|
159
src/ws/client.rs
159
src/ws/client.rs
|
@ -31,9 +31,6 @@ use super::Message;
|
||||||
use super::frame::Frame;
|
use super::frame::Frame;
|
||||||
use super::proto::{CloseCode, OpCode};
|
use super::proto::{CloseCode, OpCode};
|
||||||
|
|
||||||
pub type WsClientFuture =
|
|
||||||
Future<Item=(WsClientReader, WsClientWriter), Error=WsClientError>;
|
|
||||||
|
|
||||||
|
|
||||||
/// Websocket client error
|
/// Websocket client error
|
||||||
#[derive(Fail, Debug)]
|
#[derive(Fail, Debug)]
|
||||||
|
@ -140,7 +137,7 @@ impl WsClient {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Set cookie for handshake request
|
/// Set cookie for handshake request
|
||||||
pub fn cookie<'c>(mut self, cookie: Cookie<'c>) -> Self {
|
pub fn cookie(mut self, cookie: Cookie) -> Self {
|
||||||
self.request.cookie(cookie);
|
self.request.cookie(cookie);
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
@ -165,49 +162,46 @@ impl WsClient {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Connect to websocket server and do ws handshake
|
/// Connect to websocket server and do ws handshake
|
||||||
pub fn connect(&mut self) -> Result<Box<WsClientFuture>, WsClientError> {
|
pub fn connect(&mut self) -> WsHandshake {
|
||||||
if let Some(e) = self.err.take() {
|
if let Some(e) = self.err.take() {
|
||||||
return Err(e)
|
WsHandshake::new(None, Some(e), &self.conn)
|
||||||
}
|
}
|
||||||
if let Some(e) = self.http_err.take() {
|
else if let Some(e) = self.http_err.take() {
|
||||||
return Err(e.into())
|
WsHandshake::new(None, Some(e.into()), &self.conn)
|
||||||
}
|
|
||||||
|
|
||||||
// origin
|
|
||||||
if let Some(origin) = self.origin.take() {
|
|
||||||
self.request.set_header(header::ORIGIN, origin);
|
|
||||||
}
|
|
||||||
|
|
||||||
self.request.upgrade();
|
|
||||||
self.request.set_header(header::UPGRADE, "websocket");
|
|
||||||
self.request.set_header(header::CONNECTION, "upgrade");
|
|
||||||
self.request.set_header("SEC-WEBSOCKET-VERSION", "13");
|
|
||||||
|
|
||||||
if let Some(protocols) = self.protocols.take() {
|
|
||||||
self.request.set_header("SEC-WEBSOCKET-PROTOCOL", protocols.as_str());
|
|
||||||
}
|
|
||||||
let request = self.request.finish()?;
|
|
||||||
|
|
||||||
if request.uri().host().is_none() {
|
|
||||||
return Err(WsClientError::InvalidUrl)
|
|
||||||
}
|
|
||||||
if let Some(scheme) = request.uri().scheme_part() {
|
|
||||||
if scheme != "http" && scheme != "https" && scheme != "ws" && scheme != "wss" {
|
|
||||||
return Err(WsClientError::InvalidUrl);
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
return Err(WsClientError::InvalidUrl);
|
// origin
|
||||||
}
|
if let Some(origin) = self.origin.take() {
|
||||||
|
self.request.set_header(header::ORIGIN, origin);
|
||||||
|
}
|
||||||
|
|
||||||
// get connection and start handshake
|
self.request.upgrade();
|
||||||
Ok(Box::new(
|
self.request.set_header(header::UPGRADE, "websocket");
|
||||||
self.conn.send(Connect(request.uri().clone()))
|
self.request.set_header(header::CONNECTION, "upgrade");
|
||||||
.map_err(|_| WsClientError::Disconnected)
|
self.request.set_header("SEC-WEBSOCKET-VERSION", "13");
|
||||||
.and_then(|res| match res {
|
|
||||||
Ok(stream) => Either::A(WsHandshake::new(stream, request)),
|
if let Some(protocols) = self.protocols.take() {
|
||||||
Err(err) => Either::B(FutErr(err.into())),
|
self.request.set_header("SEC-WEBSOCKET-PROTOCOL", protocols.as_str());
|
||||||
})
|
}
|
||||||
))
|
let request = match self.request.finish() {
|
||||||
|
Ok(req) => req,
|
||||||
|
Err(err) => return WsHandshake::new(None, Some(err.into()), &self.conn),
|
||||||
|
};
|
||||||
|
|
||||||
|
if request.uri().host().is_none() {
|
||||||
|
return WsHandshake::new(None, Some(WsClientError::InvalidUrl), &self.conn)
|
||||||
|
}
|
||||||
|
if let Some(scheme) = request.uri().scheme_part() {
|
||||||
|
if scheme != "http" && scheme != "https" && scheme != "ws" && scheme != "wss" {
|
||||||
|
return WsHandshake::new(
|
||||||
|
None, Some(WsClientError::InvalidUrl), &self.conn)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return WsHandshake::new(None, Some(WsClientError::InvalidUrl), &self.conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// start handshake
|
||||||
|
WsHandshake::new(Some(request), None, &self.conn)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -220,39 +214,53 @@ struct WsInner {
|
||||||
error_sent: bool,
|
error_sent: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
struct WsHandshake {
|
pub struct WsHandshake {
|
||||||
inner: Option<WsInner>,
|
inner: Option<WsInner>,
|
||||||
request: ClientRequest,
|
request: Option<ClientRequest>,
|
||||||
sent: bool,
|
sent: bool,
|
||||||
key: String,
|
key: String,
|
||||||
|
error: Option<WsClientError>,
|
||||||
|
stream: Option<Box<Future<Item=Result<Connection, WsClientError>, Error=WsClientError>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl WsHandshake {
|
impl WsHandshake {
|
||||||
fn new(conn: Connection, mut request: ClientRequest) -> WsHandshake {
|
fn new(request: Option<ClientRequest>,
|
||||||
|
err: Option<WsClientError>,
|
||||||
|
conn: &Addr<Unsync, ClientConnector>) -> WsHandshake
|
||||||
|
{
|
||||||
// Generate a random key for the `Sec-WebSocket-Key` header.
|
// Generate a random key for the `Sec-WebSocket-Key` header.
|
||||||
// a base64-encoded (see Section 4 of [RFC4648]) value that,
|
// a base64-encoded (see Section 4 of [RFC4648]) value that,
|
||||||
// when decoded, is 16 bytes in length (RFC 6455)
|
// when decoded, is 16 bytes in length (RFC 6455)
|
||||||
let sec_key: [u8; 16] = rand::random();
|
let sec_key: [u8; 16] = rand::random();
|
||||||
let key = base64::encode(&sec_key);
|
let key = base64::encode(&sec_key);
|
||||||
|
|
||||||
request.headers_mut().insert(
|
if let Some(mut request) = request {
|
||||||
HeaderName::try_from("SEC-WEBSOCKET-KEY").unwrap(),
|
let stream = Box::new(
|
||||||
HeaderValue::try_from(key.as_str()).unwrap());
|
conn.send(Connect(request.uri().clone()))
|
||||||
|
.map(|res| res.map_err(|e| e.into()))
|
||||||
|
.map_err(|_| WsClientError::Disconnected));
|
||||||
|
|
||||||
let inner = WsInner {
|
request.headers_mut().insert(
|
||||||
conn: conn,
|
HeaderName::try_from("SEC-WEBSOCKET-KEY").unwrap(),
|
||||||
writer: HttpClientWriter::new(SharedBytes::default()),
|
HeaderValue::try_from(key.as_str()).unwrap());
|
||||||
parser: HttpResponseParser::default(),
|
|
||||||
parser_buf: BytesMut::new(),
|
|
||||||
closed: false,
|
|
||||||
error_sent: false,
|
|
||||||
};
|
|
||||||
|
|
||||||
WsHandshake {
|
WsHandshake {
|
||||||
key: key,
|
key: key,
|
||||||
inner: Some(inner),
|
inner: None,
|
||||||
request: request,
|
request: Some(request),
|
||||||
sent: false,
|
sent: false,
|
||||||
|
error: err,
|
||||||
|
stream: Some(stream),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
WsHandshake {
|
||||||
|
key: key,
|
||||||
|
inner: None,
|
||||||
|
request: None,
|
||||||
|
sent: false,
|
||||||
|
error: err,
|
||||||
|
stream: None,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -262,11 +270,36 @@ impl Future for WsHandshake {
|
||||||
type Error = WsClientError;
|
type Error = WsClientError;
|
||||||
|
|
||||||
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
|
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
|
||||||
|
if let Some(err) = self.error.take() {
|
||||||
|
return Err(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.stream.is_some() {
|
||||||
|
match self.stream.as_mut().unwrap().poll()? {
|
||||||
|
Async::Ready(result) => match result {
|
||||||
|
Ok(conn) => {
|
||||||
|
let inner = WsInner {
|
||||||
|
conn: conn,
|
||||||
|
writer: HttpClientWriter::new(SharedBytes::default()),
|
||||||
|
parser: HttpResponseParser::default(),
|
||||||
|
parser_buf: BytesMut::new(),
|
||||||
|
closed: false,
|
||||||
|
error_sent: false,
|
||||||
|
};
|
||||||
|
self.stream.take();
|
||||||
|
self.inner = Some(inner);
|
||||||
|
}
|
||||||
|
Err(err) => return Err(err),
|
||||||
|
},
|
||||||
|
Async::NotReady => return Ok(Async::NotReady)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let mut inner = self.inner.take().unwrap();
|
let mut inner = self.inner.take().unwrap();
|
||||||
|
|
||||||
if !self.sent {
|
if !self.sent {
|
||||||
self.sent = true;
|
self.sent = true;
|
||||||
inner.writer.start(&mut self.request)?;
|
inner.writer.start(self.request.as_mut().unwrap())?;
|
||||||
}
|
}
|
||||||
if let Err(err) = inner.writer.poll_completed(&mut inner.conn, false) {
|
if let Err(err) = inner.writer.poll_completed(&mut inner.conn, false) {
|
||||||
return Err(err.into())
|
return Err(err.into())
|
||||||
|
|
|
@ -65,7 +65,7 @@ use self::frame::Frame;
|
||||||
use self::proto::{hash_key, OpCode};
|
use self::proto::{hash_key, OpCode};
|
||||||
pub use self::proto::CloseCode;
|
pub use self::proto::CloseCode;
|
||||||
pub use self::context::WebsocketContext;
|
pub use self::context::WebsocketContext;
|
||||||
pub use self::client::{WsClient, WsClientError, WsClientReader, WsClientWriter, WsClientFuture};
|
pub use self::client::{WsClient, WsClientError, WsClientReader, WsClientWriter, WsHandshake};
|
||||||
|
|
||||||
const SEC_WEBSOCKET_ACCEPT: &str = "SEC-WEBSOCKET-ACCEPT";
|
const SEC_WEBSOCKET_ACCEPT: &str = "SEC-WEBSOCKET-ACCEPT";
|
||||||
const SEC_WEBSOCKET_KEY: &str = "SEC-WEBSOCKET-KEY";
|
const SEC_WEBSOCKET_KEY: &str = "SEC-WEBSOCKET-KEY";
|
||||||
|
|
Loading…
Reference in a new issue