diff --git a/src/counter.rs b/src/counter.rs new file mode 100644 index 000000000..d2bd02266 --- /dev/null +++ b/src/counter.rs @@ -0,0 +1,76 @@ +use std::cell::Cell; +use std::rc::Rc; + +use futures::task::AtomicTask; + +#[derive(Clone)] +/// Simple counter with ability to notify task on reaching specific number +/// +/// Counter could be cloned, total ncount is shared across all clones. +pub struct Counter(Rc); + +struct CounterInner { + count: Cell, + max: usize, + task: AtomicTask, +} + +impl Counter { + /// Create `Counter` instance and set max value. + pub fn new(max: usize) -> Self { + Counter(Rc::new(CounterInner { + max, + count: Cell::new(0), + task: AtomicTask::new(), + })) + } + + pub fn get(&self) -> CounterGuard { + CounterGuard::new(self.0.clone()) + } + + pub fn check(&self) -> bool { + self.0.check() + } + + pub fn total(&self) -> usize { + self.0.count.get() + } +} + +pub struct CounterGuard(Rc); + +impl CounterGuard { + fn new(inner: Rc) -> Self { + inner.inc(); + CounterGuard(inner) + } +} + +impl Drop for CounterGuard { + fn drop(&mut self) { + self.0.dec(); + } +} + +impl CounterInner { + fn inc(&self) { + let num = self.count.get() + 1; + self.count.set(num); + if num == self.max { + self.task.register(); + } + } + + fn dec(&self) { + let num = self.count.get(); + self.count.set(num - 1); + if num == self.max { + self.task.notify(); + } + } + + fn check(&self) -> bool { + self.count.get() < self.max + } +} diff --git a/src/lib.rs b/src/lib.rs index ca4a65736..446222f70 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -56,6 +56,7 @@ extern crate webpki; extern crate webpki_roots; pub mod connector; +pub mod counter; pub mod framed; pub mod resolver; pub mod server; diff --git a/src/server/mod.rs b/src/server/mod.rs index b080ffe54..0f8eb685b 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -10,9 +10,6 @@ mod worker; pub use self::server::Server; pub use self::services::ServerServiceFactory; -#[allow(unused_imports)] -pub(crate) use self::worker::{Connections, ConnectionsGuard}; - /// Pause accepting incoming connections /// /// If socket contains some pending connection, they might be dropped. diff --git a/src/server/worker.rs b/src/server/worker.rs index eb8f7e48f..a741c35d2 100644 --- a/src/server/worker.rs +++ b/src/server/worker.rs @@ -1,12 +1,9 @@ -use std::cell::Cell; -use std::rc::Rc; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::Arc; use std::{mem, net, time}; use futures::sync::mpsc::{UnboundedReceiver, UnboundedSender}; use futures::sync::oneshot; -use futures::task::AtomicTask; use futures::{future, Async, Future, Poll, Stream}; use tokio_current_thread::spawn; use tokio_timer::{sleep, Delay}; @@ -17,6 +14,7 @@ use actix::{Arbiter, Message}; use super::accept::AcceptNotify; use super::services::{BoxedServerService, InternalServerServiceFactory, ServerMessage}; use super::Token; +use counter::Counter; pub(crate) enum WorkerCommand { Message(Conn), @@ -50,8 +48,8 @@ pub(crate) fn num_connections() -> usize { } thread_local! { - static MAX_CONNS_COUNTER: Connections = - Connections::new(MAX_CONNS.load(Ordering::Relaxed)); + static MAX_CONNS_COUNTER: Counter = + Counter::new(MAX_CONNS.load(Ordering::Relaxed)); } #[derive(Clone)] @@ -122,7 +120,7 @@ pub(crate) struct Worker { rx: UnboundedReceiver, services: Vec, availability: WorkerAvailability, - conns: Connections, + conns: Counter, factories: Vec>, state: WorkerState, } @@ -308,8 +306,7 @@ impl Future for Worker { match self.rx.poll() { // handle incoming tcp stream Ok(Async::Ready(Some(WorkerCommand::Message(msg)))) => { - match self.check_readiness() - { + match self.check_readiness() { Ok(true) => { let guard = self.conns.get(); spawn( @@ -320,7 +317,7 @@ impl Future for Worker { val }), ); - continue + continue; } Ok(false) => { trace!("Serveice is unsavailable"); @@ -330,12 +327,14 @@ impl Future for Worker { Err(idx) => { trace!("Serveice failed, restarting"); self.availability.set(false); - self.state = - WorkerState::Restarting(idx, self.factories[idx].create()); + self.state = WorkerState::Restarting( + idx, + self.factories[idx].create(), + ); } } return self.poll(); - }, + } // `StopWorker` message handler Ok(Async::Ready(Some(WorkerCommand::Stop(graceful, tx)))) => { self.availability.set(false); @@ -379,71 +378,3 @@ impl Future for Worker { Ok(Async::NotReady) } } - -#[derive(Clone)] -pub(crate) struct Connections(Rc); - -struct ConnectionsInner { - count: Cell, - maxconn: usize, - task: AtomicTask, -} - -impl Connections { - pub fn new(maxconn: usize) -> Self { - Connections(Rc::new(ConnectionsInner { - maxconn, - count: Cell::new(0), - task: AtomicTask::new(), - })) - } - - pub fn get(&self) -> ConnectionsGuard { - ConnectionsGuard::new(self.0.clone()) - } - - pub fn check(&self) -> bool { - self.0.check() - } - - pub fn total(&self) -> usize { - self.0.count.get() - } -} - -pub(crate) struct ConnectionsGuard(Rc); - -impl ConnectionsGuard { - fn new(inner: Rc) -> Self { - inner.inc(); - ConnectionsGuard(inner) - } -} - -impl Drop for ConnectionsGuard { - fn drop(&mut self) { - self.0.dec(); - } -} - -impl ConnectionsInner { - fn inc(&self) { - let num = self.count.get() + 1; - self.count.set(num); - if num == self.maxconn { - self.task.register(); - } - } - - fn dec(&self) { - let num = self.count.get(); - self.count.set(num - 1); - if num == self.maxconn { - self.task.notify(); - } - } - - fn check(&self) -> bool { - self.count.get() < self.maxconn - } -} diff --git a/src/ssl/mod.rs b/src/ssl/mod.rs index 8d56a8913..f512ab299 100644 --- a/src/ssl/mod.rs +++ b/src/ssl/mod.rs @@ -1,7 +1,7 @@ //! SSL Services use std::sync::atomic::{AtomicUsize, Ordering}; -use super::server::Connections; +use super::counter::Counter; #[cfg(feature = "ssl")] mod openssl; @@ -21,7 +21,7 @@ pub fn max_concurrent_ssl_connect(num: usize) { } thread_local! { - static MAX_CONN_COUNTER: Connections = Connections::new(MAX_CONN.load(Ordering::Relaxed)); + static MAX_CONN_COUNTER: Counter = Counter::new(MAX_CONN.load(Ordering::Relaxed)); } // #[cfg(feature = "tls")] diff --git a/src/ssl/openssl.rs b/src/ssl/openssl.rs index 8a84062d7..17083c581 100644 --- a/src/ssl/openssl.rs +++ b/src/ssl/openssl.rs @@ -7,7 +7,7 @@ use tokio_openssl::{AcceptAsync, ConnectAsync, SslAcceptorExt, SslConnectorExt, use super::MAX_CONN_COUNTER; use connector::ConnectionInfo; -use server::{Connections, ConnectionsGuard}; +use counter::{Counter, CounterGuard}; use service::{NewService, Service}; /// Support `SSL` connections via openssl package @@ -59,7 +59,7 @@ impl NewService for OpensslAcceptor { pub struct OpensslAcceptorService { acceptor: SslAcceptor, io: PhantomData, - conns: Connections, + conns: Counter, } impl Service for OpensslAcceptorService { @@ -89,7 +89,7 @@ where T: AsyncRead + AsyncWrite, { fut: AcceptAsync, - _guard: ConnectionsGuard, + _guard: CounterGuard, } impl Future for OpensslAcceptorServiceFut {