Massive RPC refactoring

This commit is contained in:
Alex Auvolat 2020-04-18 19:21:34 +02:00
parent 3f40ef149f
commit f41583e1b7
10 changed files with 577 additions and 451 deletions

View file

@ -9,7 +9,6 @@ use hyper::server::conn::AddrStream;
use hyper::service::{make_service_fn, service_fn}; use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Method, Request, Response, Server, StatusCode}; use hyper::{Body, Method, Request, Response, Server, StatusCode};
use crate::block::*;
use crate::data::*; use crate::data::*;
use crate::error::Error; use crate::error::Error;
use crate::http_util::*; use crate::http_util::*;
@ -151,7 +150,9 @@ async fn handle_put(
let mut next_offset = first_block.len(); let mut next_offset = first_block.len();
let mut put_curr_version_block = let mut put_curr_version_block =
put_block_meta(garage.clone(), &version, 0, first_block_hash.clone()); put_block_meta(garage.clone(), &version, 0, first_block_hash.clone());
let mut put_curr_block = rpc_put_block(&garage.system, first_block_hash, first_block); let mut put_curr_block = garage
.block_manager
.rpc_put_block(first_block_hash, first_block);
loop { loop {
let (_, _, next_block) = let (_, _, next_block) =
@ -165,7 +166,7 @@ async fn handle_put(
next_offset as u64, next_offset as u64,
block_hash.clone(), block_hash.clone(),
); );
put_curr_block = rpc_put_block(&garage.system, block_hash, block); put_curr_block = garage.block_manager.rpc_put_block(block_hash, block);
next_offset += block_len; next_offset += block_len;
} else { } else {
break; break;
@ -300,7 +301,7 @@ async fn handle_get(
Ok(resp_builder.body(body)?) Ok(resp_builder.body(body)?)
} }
ObjectVersionData::FirstBlock(first_block_hash) => { ObjectVersionData::FirstBlock(first_block_hash) => {
let read_first_block = rpc_get_block(&garage.system, &first_block_hash); let read_first_block = garage.block_manager.rpc_get_block(&first_block_hash);
let get_next_blocks = garage.version_table.get(&last_v.uuid, &EmptySortKey); let get_next_blocks = garage.version_table.get(&last_v.uuid, &EmptySortKey);
let (first_block, version) = futures::try_join!(read_first_block, get_next_blocks)?; let (first_block, version) = futures::try_join!(read_first_block, get_next_blocks)?;
@ -323,7 +324,11 @@ async fn handle_get(
if let Some(data) = data_opt { if let Some(data) = data_opt {
Ok(Bytes::from(data)) Ok(Bytes::from(data))
} else { } else {
rpc_get_block(&garage.system, &hash).await.map(Bytes::from) garage
.block_manager
.rpc_get_block(&hash)
.await
.map(Bytes::from)
} }
} }
}) })

View file

@ -5,6 +5,7 @@ use std::time::Duration;
use arc_swap::ArcSwapOption; use arc_swap::ArcSwapOption;
use futures::future::*; use futures::future::*;
use futures::stream::*; use futures::stream::*;
use serde::{Deserialize, Serialize};
use tokio::fs; use tokio::fs;
use tokio::prelude::*; use tokio::prelude::*;
use tokio::sync::{watch, Mutex}; use tokio::sync::{watch, Mutex};
@ -15,22 +16,40 @@ use crate::error::Error;
use crate::membership::System; use crate::membership::System;
use crate::proto::*; use crate::proto::*;
use crate::rpc_client::*; use crate::rpc_client::*;
use crate::rpc_server::*;
use crate::server::Garage; use crate::server::Garage;
const NEED_BLOCK_QUERY_TIMEOUT: Duration = Duration::from_secs(5); const NEED_BLOCK_QUERY_TIMEOUT: Duration = Duration::from_secs(5);
const RESYNC_RETRY_TIMEOUT: Duration = Duration::from_secs(10); const RESYNC_RETRY_TIMEOUT: Duration = Duration::from_secs(10);
#[derive(Debug, Serialize, Deserialize)]
pub enum Message {
Ok,
GetBlock(Hash),
PutBlock(PutBlockMessage),
NeedBlockQuery(Hash),
NeedBlockReply(bool),
}
impl RpcMessage for Message {}
pub struct BlockManager { pub struct BlockManager {
pub data_dir: PathBuf, pub data_dir: PathBuf,
pub rc: sled::Tree, pub rc: sled::Tree,
pub resync_queue: sled::Tree, pub resync_queue: sled::Tree,
pub lock: Mutex<()>, pub lock: Mutex<()>,
pub system: Arc<System>, pub system: Arc<System>,
rpc_client: Arc<RpcClient<Message>>,
pub garage: ArcSwapOption<Garage>, pub garage: ArcSwapOption<Garage>,
} }
impl BlockManager { impl BlockManager {
pub fn new(db: &sled::Db, data_dir: PathBuf, system: Arc<System>) -> Arc<Self> { pub fn new(
db: &sled::Db,
data_dir: PathBuf,
system: Arc<System>,
rpc_server: &mut RpcServer,
) -> Arc<Self> {
let rc = db let rc = db
.open_tree("block_local_rc") .open_tree("block_local_rc")
.expect("Unable to open block_local_rc tree"); .expect("Unable to open block_local_rc tree");
@ -40,14 +59,38 @@ impl BlockManager {
.open_tree("block_local_resync_queue") .open_tree("block_local_resync_queue")
.expect("Unable to open block_local_resync_queue tree"); .expect("Unable to open block_local_resync_queue tree");
Arc::new(Self { let rpc_path = "block_manager";
let rpc_client = system.rpc_client::<Message>(rpc_path);
let block_manager = Arc::new(Self {
rc, rc,
resync_queue, resync_queue,
data_dir, data_dir,
lock: Mutex::new(()), lock: Mutex::new(()),
system, system,
rpc_client,
garage: ArcSwapOption::from(None), garage: ArcSwapOption::from(None),
}) });
block_manager
.clone()
.register_handler(rpc_server, rpc_path.into());
block_manager
}
fn register_handler(self: Arc<Self>, rpc_server: &mut RpcServer, path: String) {
rpc_server.add_handler::<Message, _, _>(path, move |msg, _addr| {
let self2 = self.clone();
async move {
match msg {
Message::PutBlock(m) => self2.write_block(&m.hash, &m.data).await,
Message::GetBlock(h) => self2.read_block(&h).await,
Message::NeedBlockQuery(h) => {
self2.need_block(&h).await.map(Message::NeedBlockReply)
}
_ => Err(Error::Message(format!("Invalid RPC"))),
}
}
});
} }
pub async fn spawn_background_worker(self: Arc<Self>) { pub async fn spawn_background_worker(self: Arc<Self>) {
@ -214,10 +257,11 @@ impl BlockManager {
if needed_by_others { if needed_by_others {
let ring = garage.system.ring.borrow().clone(); let ring = garage.system.ring.borrow().clone();
let who = ring.walk_ring(&hash, garage.system.config.data_replication_factor); let who = ring.walk_ring(&hash, garage.system.config.data_replication_factor);
let msg = Message::NeedBlockQuery(hash.clone()); let msg = Arc::new(Message::NeedBlockQuery(hash.clone()));
let who_needs_fut = who let who_needs_fut = who.iter().map(|to| {
.iter() self.rpc_client
.map(|to| rpc_call(garage.system.clone(), to, &msg, NEED_BLOCK_QUERY_TIMEOUT)); .call(to, msg.clone(), NEED_BLOCK_QUERY_TIMEOUT)
});
let who_needs = join_all(who_needs_fut).await; let who_needs = join_all(who_needs_fut).await;
let mut need_nodes = vec![]; let mut need_nodes = vec![];
@ -242,12 +286,9 @@ impl BlockManager {
if need_nodes.len() > 0 { if need_nodes.len() > 0 {
let put_block_message = self.read_block(hash).await?; let put_block_message = self.read_block(hash).await?;
let put_responses = rpc_call_many( let put_responses = self
garage.system.clone(), .rpc_client
&need_nodes[..], .call_many(&need_nodes[..], put_block_message, BLOCK_RW_TIMEOUT)
put_block_message,
BLOCK_RW_TIMEOUT,
)
.await; .await;
for resp in put_responses { for resp in put_responses {
resp?; resp?;
@ -262,12 +303,48 @@ impl BlockManager {
// TODO find a way to not do this if they are sending it to us // TODO find a way to not do this if they are sending it to us
// Let's suppose this isn't an issue for now with the BLOCK_RW_TIMEOUT delay // Let's suppose this isn't an issue for now with the BLOCK_RW_TIMEOUT delay
// between the RC being incremented and this part being called. // between the RC being incremented and this part being called.
let block_data = rpc_get_block(&self.system, &hash).await?; let block_data = self.rpc_get_block(&hash).await?;
self.write_block(hash, &block_data[..]).await?; self.write_block(hash, &block_data[..]).await?;
} }
Ok(()) Ok(())
} }
pub async fn rpc_get_block(&self, hash: &Hash) -> Result<Vec<u8>, Error> {
let ring = self.system.ring.borrow().clone();
let who = ring.walk_ring(&hash, self.system.config.data_replication_factor);
let msg = Arc::new(Message::GetBlock(hash.clone()));
let mut resp_stream = who
.iter()
.map(|to| self.rpc_client.call(to, msg.clone(), BLOCK_RW_TIMEOUT))
.collect::<FuturesUnordered<_>>();
while let Some(resp) = resp_stream.next().await {
if let Ok(Message::PutBlock(msg)) = resp {
if data::hash(&msg.data[..]) == *hash {
return Ok(msg.data);
}
}
}
Err(Error::Message(format!(
"Unable to read block {:?}: no valid blocks returned",
hash
)))
}
pub async fn rpc_put_block(&self, hash: Hash, data: Vec<u8>) -> Result<(), Error> {
let ring = self.system.ring.borrow().clone();
let who = ring.walk_ring(&hash, self.system.config.data_replication_factor);
self.rpc_client
.try_call_many(
&who[..],
Message::PutBlock(PutBlockMessage { hash, data }),
(self.system.config.data_replication_factor + 1) / 2,
BLOCK_RW_TIMEOUT,
)
.await?;
Ok(())
}
} }
fn u64_from_bytes(bytes: &[u8]) -> u64 { fn u64_from_bytes(bytes: &[u8]) -> u64 {
@ -297,39 +374,3 @@ fn rc_merge(_key: &[u8], old: Option<&[u8]>, new: &[u8]) -> Option<Vec<u8>> {
Some(u64::to_be_bytes(new).to_vec()) Some(u64::to_be_bytes(new).to_vec())
} }
} }
pub async fn rpc_get_block(system: &Arc<System>, hash: &Hash) -> Result<Vec<u8>, Error> {
let ring = system.ring.borrow().clone();
let who = ring.walk_ring(&hash, system.config.data_replication_factor);
let msg = Message::GetBlock(hash.clone());
let mut resp_stream = who
.iter()
.map(|to| rpc_call(system.clone(), to, &msg, BLOCK_RW_TIMEOUT))
.collect::<FuturesUnordered<_>>();
while let Some(resp) = resp_stream.next().await {
if let Ok(Message::PutBlock(msg)) = resp {
if data::hash(&msg.data[..]) == *hash {
return Ok(msg.data);
}
}
}
Err(Error::Message(format!(
"Unable to read block {:?}: no valid blocks returned",
hash
)))
}
pub async fn rpc_put_block(system: &Arc<System>, hash: Hash, data: Vec<u8>) -> Result<(), Error> {
let ring = system.ring.borrow().clone();
let who = ring.walk_ring(&hash, system.config.data_replication_factor);
rpc_try_call_many(
system.clone(),
&who[..],
Message::PutBlock(PutBlockMessage { hash, data }),
(system.config.data_replication_factor + 1) / 2,
BLOCK_RW_TIMEOUT,
)
.await?;
Ok(())
}

View file

@ -22,12 +22,14 @@ mod tls_util;
use std::collections::HashSet; use std::collections::HashSet;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::Arc;
use structopt::StructOpt; use structopt::StructOpt;
use data::*; use data::*;
use error::Error; use error::Error;
use membership::Message;
use proto::*; use proto::*;
use rpc_client::RpcClient; use rpc_client::*;
use server::TlsConfig; use server::TlsConfig;
#[derive(StructOpt, Debug)] #[derive(StructOpt, Debug)]
@ -113,7 +115,9 @@ async fn main() {
} }
}; };
let rpc_cli = RpcClient::new(&tls_config).expect("Could not create RPC client"); let rpc_http_cli =
Arc::new(RpcHttpClient::new(&tls_config).expect("Could not create RPC client"));
let rpc_cli = RpcAddrClient::new(rpc_http_cli, "_membership".into());
let resp = match opt.cmd { let resp = match opt.cmd {
Command::Server(server_opt) => { Command::Server(server_opt) => {
@ -137,7 +141,7 @@ async fn main() {
} }
} }
async fn cmd_status(rpc_cli: RpcClient, rpc_host: SocketAddr) -> Result<(), Error> { async fn cmd_status(rpc_cli: RpcAddrClient<Message>, rpc_host: SocketAddr) -> Result<(), Error> {
let status = match rpc_cli let status = match rpc_cli
.call(&rpc_host, &Message::PullStatus, DEFAULT_TIMEOUT) .call(&rpc_host, &Message::PullStatus, DEFAULT_TIMEOUT)
.await? .await?
@ -196,7 +200,7 @@ async fn cmd_status(rpc_cli: RpcClient, rpc_host: SocketAddr) -> Result<(), Erro
} }
async fn cmd_configure( async fn cmd_configure(
rpc_cli: RpcClient, rpc_cli: RpcAddrClient<Message>,
rpc_host: SocketAddr, rpc_host: SocketAddr,
args: ConfigureOpt, args: ConfigureOpt,
) -> Result<(), Error> { ) -> Result<(), Error> {
@ -249,7 +253,7 @@ async fn cmd_configure(
} }
async fn cmd_remove( async fn cmd_remove(
rpc_cli: RpcClient, rpc_cli: RpcAddrClient<Message>,
rpc_host: SocketAddr, rpc_host: SocketAddr,
args: RemoveOpt, args: RemoveOpt,
) -> Result<(), Error> { ) -> Result<(), Error> {

View file

@ -10,6 +10,7 @@ use std::time::Duration;
use futures::future::join_all; use futures::future::join_all;
use futures::select; use futures::select;
use futures_util::future::*; use futures_util::future::*;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256}; use sha2::{Digest, Sha256};
use tokio::prelude::*; use tokio::prelude::*;
use tokio::sync::watch; use tokio::sync::watch;
@ -20,17 +21,31 @@ use crate::data::*;
use crate::error::Error; use crate::error::Error;
use crate::proto::*; use crate::proto::*;
use crate::rpc_client::*; use crate::rpc_client::*;
use crate::rpc_server::*;
use crate::server::Config; use crate::server::Config;
const PING_INTERVAL: Duration = Duration::from_secs(10); const PING_INTERVAL: Duration = Duration::from_secs(10);
const PING_TIMEOUT: Duration = Duration::from_secs(2); const PING_TIMEOUT: Duration = Duration::from_secs(2);
const MAX_FAILED_PINGS: usize = 3; const MAX_FAILED_PINGS: usize = 3;
#[derive(Debug, Serialize, Deserialize)]
pub enum Message {
Ok,
Ping(PingMessage),
PullStatus,
PullConfig,
AdvertiseNodesUp(Vec<AdvertisedNode>),
AdvertiseConfig(NetworkConfig),
}
impl RpcMessage for Message {}
pub struct System { pub struct System {
pub config: Config, pub config: Config,
pub id: UUID, pub id: UUID,
pub rpc_client: RpcClient, pub rpc_http_client: Arc<RpcHttpClient>,
rpc_client: Arc<RpcClient<Message>>,
pub status: watch::Receiver<Arc<Status>>, pub status: watch::Receiver<Arc<Status>>,
pub ring: watch::Receiver<Arc<Ring>>, pub ring: watch::Receiver<Arc<Ring>>,
@ -199,7 +214,12 @@ fn read_network_config(metadata_dir: &PathBuf) -> Result<NetworkConfig, Error> {
} }
impl System { impl System {
pub fn new(config: Config, id: UUID, background: Arc<BackgroundRunner>) -> Self { pub fn new(
config: Config,
id: UUID,
background: Arc<BackgroundRunner>,
rpc_server: &mut RpcServer,
) -> Arc<Self> {
let net_config = match read_network_config(&config.metadata_dir) { let net_config = match read_network_config(&config.metadata_dir) {
Ok(x) => x, Ok(x) => x,
Err(e) => { Err(e) => {
@ -228,17 +248,54 @@ impl System {
ring.rebuild_ring(); ring.rebuild_ring();
let (update_ring, ring) = watch::channel(Arc::new(ring)); let (update_ring, ring) = watch::channel(Arc::new(ring));
let rpc_client = RpcClient::new(&config.rpc_tls).expect("Could not create RPC client"); let rpc_http_client =
Arc::new(RpcHttpClient::new(&config.rpc_tls).expect("Could not create RPC client"));
System { let rpc_path = "_membership";
let rpc_client = RpcClient::new(
RpcAddrClient::<Message>::new(rpc_http_client.clone(), rpc_path.into()),
background.clone(),
status.clone(),
);
let sys = Arc::new(System {
config, config,
id, id,
rpc_http_client,
rpc_client, rpc_client,
status, status,
ring, ring,
update_lock: Mutex::new((update_status, update_ring)), update_lock: Mutex::new((update_status, update_ring)),
background, background,
});
sys.clone().register_handler(rpc_server, rpc_path.into());
sys
} }
fn register_handler(self: Arc<Self>, rpc_server: &mut RpcServer, path: String) {
rpc_server.add_handler::<Message, _, _>(path, move |msg, addr| {
let self2 = self.clone();
async move {
match msg {
Message::Ping(ping) => self2.handle_ping(&addr, &ping).await,
Message::PullStatus => self2.handle_pull_status(),
Message::PullConfig => self2.handle_pull_config(),
Message::AdvertiseNodesUp(adv) => self2.handle_advertise_nodes_up(&adv).await,
Message::AdvertiseConfig(adv) => self2.handle_advertise_config(&adv).await,
_ => Err(Error::Message(format!("Unexpected RPC message"))),
}
}
});
}
pub fn rpc_client<M: RpcMessage + 'static>(self: &Arc<Self>, path: &str) -> Arc<RpcClient<M>> {
RpcClient::new(
RpcAddrClient::new(self.rpc_http_client.clone(), path.to_string()),
self.background.clone(),
self.status.clone(),
)
} }
async fn save_network_config(self: Arc<Self>) -> Result<(), Error> { async fn save_network_config(self: Arc<Self>) -> Result<(), Error> {
@ -272,7 +329,7 @@ impl System {
.filter(|x| **x != self.id) .filter(|x| **x != self.id)
.cloned() .cloned()
.collect::<Vec<_>>(); .collect::<Vec<_>>();
rpc_call_many(self.clone(), &to[..], msg, timeout).await; self.rpc_client.call_many(&to[..], msg, timeout).await;
} }
pub async fn bootstrap(self: Arc<Self>) { pub async fn bootstrap(self: Arc<Self>) {
@ -299,7 +356,10 @@ impl System {
( (
id_option, id_option,
addr.clone(), addr.clone(),
sys.rpc_client.call(&addr, ping_msg_ref, PING_TIMEOUT).await, sys.rpc_client
.by_addr()
.call(&addr, ping_msg_ref, PING_TIMEOUT)
.await,
) )
} }
})) }))
@ -509,7 +569,10 @@ impl System {
peer: UUID, peer: UUID,
) -> impl futures::future::Future<Output = ()> + Send + 'static { ) -> impl futures::future::Future<Output = ()> + Send + 'static {
async move { async move {
let resp = rpc_call(self.clone(), &peer, &Message::PullStatus, PING_TIMEOUT).await; let resp = self
.rpc_client
.call(&peer, Message::PullStatus, PING_TIMEOUT)
.await;
if let Ok(Message::AdvertiseNodesUp(nodes)) = resp { if let Ok(Message::AdvertiseNodesUp(nodes)) = resp {
let _: Result<_, _> = self.handle_advertise_nodes_up(&nodes).await; let _: Result<_, _> = self.handle_advertise_nodes_up(&nodes).await;
} }
@ -517,7 +580,10 @@ impl System {
} }
pub async fn pull_config(self: Arc<Self>, peer: UUID) { pub async fn pull_config(self: Arc<Self>, peer: UUID) {
let resp = rpc_call(self.clone(), &peer, &Message::PullConfig, PING_TIMEOUT).await; let resp = self
.rpc_client
.call(&peer, Message::PullConfig, PING_TIMEOUT)
.await;
if let Ok(Message::AdvertiseConfig(config)) = resp { if let Ok(Message::AdvertiseConfig(config)) = resp {
let _: Result<_, _> = self.handle_advertise_config(&config).await; let _: Result<_, _> = self.handle_advertise_config(&config).await;
} }

View file

@ -7,25 +7,6 @@ use crate::data::*;
pub const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10); pub const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10);
pub const BLOCK_RW_TIMEOUT: Duration = Duration::from_secs(42); pub const BLOCK_RW_TIMEOUT: Duration = Duration::from_secs(42);
#[derive(Debug, Serialize, Deserialize)]
pub enum Message {
Ok,
Error(String),
Ping(PingMessage),
PullStatus,
PullConfig,
AdvertiseNodesUp(Vec<AdvertisedNode>),
AdvertiseConfig(NetworkConfig),
GetBlock(Hash),
PutBlock(PutBlockMessage),
NeedBlockQuery(Hash),
NeedBlockReply(bool),
TableRPC(String, #[serde(with = "serde_bytes")] Vec<u8>),
}
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub struct PingMessage { pub struct PingMessage {
pub id: UUID, pub id: UUID,

View file

@ -1,4 +1,5 @@
use std::borrow::Borrow; use std::borrow::Borrow;
use std::marker::PhantomData;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
@ -9,24 +10,66 @@ use futures::stream::StreamExt;
use futures_util::future::FutureExt; use futures_util::future::FutureExt;
use hyper::client::{Client, HttpConnector}; use hyper::client::{Client, HttpConnector};
use hyper::{Body, Method, Request, StatusCode}; use hyper::{Body, Method, Request, StatusCode};
use tokio::sync::watch;
use crate::background::*;
use crate::data::*; use crate::data::*;
use crate::error::Error; use crate::error::Error;
use crate::membership::System; use crate::membership::Status;
use crate::proto::Message; use crate::rpc_server::RpcMessage;
use crate::server::*; use crate::server::*;
use crate::tls_util; use crate::tls_util;
pub async fn rpc_call_many( pub struct RpcClient<M: RpcMessage> {
sys: Arc<System>, status: watch::Receiver<Arc<Status>>,
to: &[UUID], background: Arc<BackgroundRunner>,
msg: Message,
pub rpc_addr_client: RpcAddrClient<M>,
}
impl<M: RpcMessage + 'static> RpcClient<M> {
pub fn new(
rac: RpcAddrClient<M>,
background: Arc<BackgroundRunner>,
status: watch::Receiver<Arc<Status>>,
) -> Arc<Self> {
Arc::new(Self {
rpc_addr_client: rac,
background,
status,
})
}
pub fn by_addr(&self) -> &RpcAddrClient<M> {
&self.rpc_addr_client
}
pub async fn call<MB: Borrow<M>, N: Borrow<UUID>>(
&self,
to: N,
msg: MB,
timeout: Duration, timeout: Duration,
) -> Vec<Result<Message, Error>> { ) -> Result<M, Error> {
let addr = {
let status = self.status.borrow().clone();
match status.nodes.get(to.borrow()) {
Some(status) => status.addr.clone(),
None => {
return Err(Error::Message(format!(
"Peer ID not found: {:?}",
to.borrow()
)))
}
}
};
self.rpc_addr_client.call(&addr, msg, timeout).await
}
pub async fn call_many(&self, to: &[UUID], msg: M, timeout: Duration) -> Vec<Result<M, Error>> {
let msg = Arc::new(msg); let msg = Arc::new(msg);
let mut resp_stream = to let mut resp_stream = to
.iter() .iter()
.map(|to| rpc_call(sys.clone(), to, msg.clone(), timeout)) .map(|to| self.call(to, msg.clone(), timeout))
.collect::<FuturesUnordered<_>>(); .collect::<FuturesUnordered<_>>();
let mut results = vec![]; let mut results = vec![];
@ -34,21 +77,24 @@ pub async fn rpc_call_many(
results.push(resp); results.push(resp);
} }
results results
} }
pub async fn rpc_try_call_many( pub async fn try_call_many(
sys: Arc<System>, self: &Arc<Self>,
to: &[UUID], to: &[UUID],
msg: Message, msg: M,
stop_after: usize, stop_after: usize,
timeout: Duration, timeout: Duration,
) -> Result<Vec<Message>, Error> { ) -> Result<Vec<M>, Error> {
let sys2 = sys.clone();
let msg = Arc::new(msg); let msg = Arc::new(msg);
let mut resp_stream = to let mut resp_stream = to
.to_vec() .to_vec()
.into_iter() .into_iter()
.map(move |to| rpc_call(sys2.clone(), to.clone(), msg.clone(), timeout)) .map(|to| {
let self2 = self.clone();
let msg = msg.clone();
async move { self2.call(to.clone(), msg, timeout).await }
})
.collect::<FuturesUnordered<_>>(); .collect::<FuturesUnordered<_>>();
let mut results = vec![]; let mut results = vec![];
@ -71,7 +117,7 @@ pub async fn rpc_try_call_many(
if results.len() >= stop_after { if results.len() >= stop_after {
// Continue requests in background // Continue requests in background
// TODO: make this optionnal (only usefull for write requests) // TODO: make this optionnal (only usefull for write requests)
sys.background.spawn(async move { self.clone().background.spawn(async move {
resp_stream.collect::<Vec<_>>().await; resp_stream.collect::<Vec<_>>().await;
Ok(()) Ok(())
}); });
@ -84,35 +130,46 @@ pub async fn rpc_try_call_many(
} }
Err(Error::Message(msg)) Err(Error::Message(msg))
} }
}
} }
pub async fn rpc_call<M: Borrow<Message>, N: Borrow<UUID>>( pub struct RpcAddrClient<M: RpcMessage> {
sys: Arc<System>, phantom: PhantomData<M>,
to: N,
msg: M, pub http_client: Arc<RpcHttpClient>,
pub path: String,
}
impl<M: RpcMessage> RpcAddrClient<M> {
pub fn new(http_client: Arc<RpcHttpClient>, path: String) -> Self {
Self {
phantom: PhantomData::default(),
http_client: http_client,
path,
}
}
pub async fn call<MB>(
&self,
to_addr: &SocketAddr,
msg: MB,
timeout: Duration, timeout: Duration,
) -> Result<Message, Error> { ) -> Result<M, Error>
let addr = { where
let status = sys.status.borrow().clone(); MB: Borrow<M>,
match status.nodes.get(to.borrow()) { {
Some(status) => status.addr.clone(), self.http_client
None => { .call(&self.path, to_addr, msg, timeout)
return Err(Error::Message(format!( .await
"Peer ID not found: {:?}",
to.borrow()
)))
} }
}
};
sys.rpc_client.call(&addr, msg, timeout).await
} }
pub enum RpcClient { pub enum RpcHttpClient {
HTTP(Client<HttpConnector, hyper::Body>), HTTP(Client<HttpConnector, hyper::Body>),
HTTPS(Client<tls_util::HttpsConnectorFixedDnsname<HttpConnector>, hyper::Body>), HTTPS(Client<tls_util::HttpsConnectorFixedDnsname<HttpConnector>, hyper::Body>),
} }
impl RpcClient { impl RpcHttpClient {
pub fn new(tls_config: &Option<TlsConfig>) -> Result<Self, Error> { pub fn new(tls_config: &Option<TlsConfig>) -> Result<Self, Error> {
if let Some(cf) = tls_config { if let Some(cf) = tls_config {
let ca_certs = tls_util::load_certs(&cf.ca_cert)?; let ca_certs = tls_util::load_certs(&cf.ca_cert)?;
@ -130,21 +187,26 @@ impl RpcClient {
let connector = let connector =
tls_util::HttpsConnectorFixedDnsname::<HttpConnector>::new(config, "garage"); tls_util::HttpsConnectorFixedDnsname::<HttpConnector>::new(config, "garage");
Ok(RpcClient::HTTPS(Client::builder().build(connector))) Ok(RpcHttpClient::HTTPS(Client::builder().build(connector)))
} else { } else {
Ok(RpcClient::HTTP(Client::new())) Ok(RpcHttpClient::HTTP(Client::new()))
} }
} }
pub async fn call<M: Borrow<Message>>( async fn call<M, MB>(
&self, &self,
path: &str,
to_addr: &SocketAddr, to_addr: &SocketAddr,
msg: M, msg: MB,
timeout: Duration, timeout: Duration,
) -> Result<Message, Error> { ) -> Result<M, Error>
where
MB: Borrow<M>,
M: RpcMessage,
{
let uri = match self { let uri = match self {
RpcClient::HTTP(_) => format!("http://{}/rpc", to_addr), RpcHttpClient::HTTP(_) => format!("http://{}/{}", to_addr, path),
RpcClient::HTTPS(_) => format!("https://{}/rpc", to_addr), RpcHttpClient::HTTPS(_) => format!("https://{}/{}", to_addr, path),
}; };
let req = Request::builder() let req = Request::builder()
@ -153,8 +215,8 @@ impl RpcClient {
.body(Body::from(rmp_to_vec_all_named(msg.borrow())?))?; .body(Body::from(rmp_to_vec_all_named(msg.borrow())?))?;
let resp_fut = match self { let resp_fut = match self {
RpcClient::HTTP(client) => client.request(req).fuse(), RpcHttpClient::HTTP(client) => client.request(req).fuse(),
RpcClient::HTTPS(client) => client.request(req).fuse(), RpcHttpClient::HTTPS(client) => client.request(req).fuse(),
}; };
let resp = tokio::time::timeout(timeout, resp_fut) let resp = tokio::time::timeout(timeout, resp_fut)
.await? .await?
@ -168,11 +230,8 @@ impl RpcClient {
if resp.status() == StatusCode::OK { if resp.status() == StatusCode::OK {
let body = hyper::body::to_bytes(resp.into_body()).await?; let body = hyper::body::to_bytes(resp.into_body()).await?;
let msg = rmp_serde::decode::from_read::<_, Message>(body.into_buf())?; let msg = rmp_serde::decode::from_read::<_, Result<M, String>>(body.into_buf())?;
match msg { msg.map_err(Error::RPCError)
Message::Error(e) => Err(Error::RPCError(e)),
x => Ok(x),
}
} else { } else {
Err(Error::RPCError(format!("Status code {}", resp.status()))) Err(Error::RPCError(format!("Status code {}", resp.status())))
} }

View file

@ -1,4 +1,6 @@
use std::collections::HashMap;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc; use std::sync::Arc;
use bytes::IntoBuf; use bytes::IntoBuf;
@ -8,100 +10,121 @@ use futures_util::stream::*;
use hyper::server::conn::AddrStream; use hyper::server::conn::AddrStream;
use hyper::service::{make_service_fn, service_fn}; use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Method, Request, Response, Server, StatusCode}; use hyper::{Body, Method, Request, Response, Server, StatusCode};
use serde::{Deserialize, Serialize};
use tokio::net::{TcpListener, TcpStream}; use tokio::net::{TcpListener, TcpStream};
use tokio_rustls::server::TlsStream; use tokio_rustls::server::TlsStream;
use tokio_rustls::TlsAcceptor; use tokio_rustls::TlsAcceptor;
use crate::data::*; use crate::data::*;
use crate::error::Error; use crate::error::Error;
use crate::proto::Message; use crate::server::TlsConfig;
use crate::server::Garage;
use crate::tls_util; use crate::tls_util;
fn err_to_msg(x: Result<Message, Error>) -> Message { pub trait RpcMessage: Serialize + for<'de> Deserialize<'de> + Send + Sync {}
match x {
Err(e) => Message::Error(format!("{}", e)), type ResponseFuture = Pin<Box<dyn Future<Output = Result<Response<Body>, Error>> + Send>>;
Ok(msg) => msg, type Handler = Box<dyn Fn(Request<Body>, SocketAddr) -> ResponseFuture + Send + Sync>;
pub struct RpcServer {
pub bind_addr: SocketAddr,
pub tls_config: Option<TlsConfig>,
handlers: HashMap<String, Handler>,
}
async fn handle_func<M, F, Fut>(
handler: Arc<F>,
req: Request<Body>,
sockaddr: SocketAddr,
) -> Result<Response<Body>, Error>
where
M: RpcMessage + 'static,
F: Fn(M, SocketAddr) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Result<M, Error>> + Send + 'static,
{
let whole_body = hyper::body::to_bytes(req.into_body()).await?;
let msg = rmp_serde::decode::from_read::<_, M>(whole_body.into_buf())?;
match handler(msg, sockaddr).await {
Ok(resp) => {
let resp_bytes = rmp_to_vec_all_named::<Result<M, String>>(&Ok(resp))?;
Ok(Response::new(Body::from(resp_bytes)))
}
Err(e) => {
let err_str = format!("{}", e);
let rep_bytes = rmp_to_vec_all_named::<Result<M, String>>(&Err(err_str))?;
let mut err_response = Response::new(Body::from(rep_bytes));
*err_response.status_mut() = e.http_status_code();
Ok(err_response)
}
} }
} }
async fn handler( impl RpcServer {
garage: Arc<Garage>, pub fn new(bind_addr: SocketAddr, tls_config: Option<TlsConfig>) -> Self {
Self {
bind_addr,
tls_config,
handlers: HashMap::new(),
}
}
pub fn add_handler<M, F, Fut>(&mut self, name: String, handler: F)
where
M: RpcMessage + 'static,
F: Fn(M, SocketAddr) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Result<M, Error>> + Send + 'static,
{
let handler_arc = Arc::new(handler);
let handler = Box::new(move |req: Request<Body>, sockaddr: SocketAddr| {
let handler2 = handler_arc.clone();
let b: ResponseFuture = Box::pin(handle_func(handler2, req, sockaddr));
b
});
self.handlers.insert(name, handler);
}
async fn handler(
self: Arc<Self>,
req: Request<Body>, req: Request<Body>,
addr: SocketAddr, addr: SocketAddr,
) -> Result<Response<Body>, Error> { ) -> Result<Response<Body>, Error> {
if req.method() != &Method::POST { if req.method() != &Method::POST {
let mut bad_request = Response::default(); let mut bad_request = Response::default();
*bad_request.status_mut() = StatusCode::BAD_REQUEST; *bad_request.status_mut() = StatusCode::BAD_REQUEST;
return Ok(bad_request); return Ok(bad_request);
} }
let whole_body = hyper::body::to_bytes(req.into_body()).await?; let path = &req.uri().path()[1..];
let msg = rmp_serde::decode::from_read::<_, Message>(whole_body.into_buf())?; let handler = match self.handlers.get(path) {
Some(h) => h,
// eprintln!( None => {
// "RPC from {}: {} ({} bytes)", let mut not_found = Response::default();
// addr, *not_found.status_mut() = StatusCode::NOT_FOUND;
// debug_serialize(&msg), return Ok(not_found);
// whole_body.len()
// );
let sys = garage.system.clone();
let resp = err_to_msg(match msg {
Message::Ping(ping) => sys.handle_ping(&addr, &ping).await,
Message::PullStatus => sys.handle_pull_status(),
Message::PullConfig => sys.handle_pull_config(),
Message::AdvertiseNodesUp(adv) => sys.handle_advertise_nodes_up(&adv).await,
Message::AdvertiseConfig(adv) => sys.handle_advertise_config(&adv).await,
Message::PutBlock(m) => {
// A RPC can be interrupted in the middle, however we don't want to write partial blocks,
// which might happen if the write_block() future is cancelled in the middle.
// To solve this, the write itself is in a spawned task that has its own separate lifetime,
// and the request handler simply sits there waiting for the task to finish.
// (if it's cancelled, that's not an issue)
// (TODO FIXME except if garage happens to shut down at that point)
let write_fut = async move { garage.block_manager.write_block(&m.hash, &m.data).await };
tokio::spawn(write_fut).await?
}
Message::GetBlock(h) => garage.block_manager.read_block(&h).await,
Message::NeedBlockQuery(h) => garage
.block_manager
.need_block(&h)
.await
.map(Message::NeedBlockReply),
Message::TableRPC(table, msg) => {
// Same trick for table RPCs than for PutBlock
let op_fut = async move {
if let Some(rpc_handler) = garage.table_rpc_handlers.get(&table) {
rpc_handler
.handle(&msg[..])
.await
.map(|rep| Message::TableRPC(table.to_string(), rep))
} else {
Ok(Message::Error(format!("Unknown table: {}", table)))
} }
}; };
tokio::spawn(op_fut).await?
let resp_waiter = tokio::spawn(handler(req, addr));
match resp_waiter.await {
Err(_err) => {
let mut ise = Response::default();
*ise.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
Ok(ise)
}
Ok(Err(err)) => {
let mut bad_request = Response::new(Body::from(format!("{}", err)));
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
Ok(bad_request)
}
Ok(Ok(resp)) => Ok(resp),
}
} }
_ => Ok(Message::Error(format!("Unexpected message: {:?}", msg))), pub async fn run(
}); self: Arc<Self>,
// eprintln!("reply to {}: {}", addr, debug_serialize(&resp));
Ok(Response::new(Body::from(rmp_to_vec_all_named(&resp)?)))
}
pub async fn run_rpc_server(
garage: Arc<Garage>,
shutdown_signal: impl Future<Output = ()>, shutdown_signal: impl Future<Output = ()>,
) -> Result<(), Error> { ) -> Result<(), Error> {
let bind_addr = ([0, 0, 0, 0, 0, 0, 0, 0], garage.system.config.rpc_port).into(); if let Some(tls_config) = self.tls_config.as_ref() {
if let Some(tls_config) = &garage.system.config.rpc_tls {
let ca_certs = tls_util::load_certs(&tls_config.ca_cert)?; let ca_certs = tls_util::load_certs(&tls_config.ca_cert)?;
let node_certs = tls_util::load_certs(&tls_config.node_cert)?; let node_certs = tls_util::load_certs(&tls_config.node_cert)?;
let node_key = tls_util::load_private_key(&tls_config.node_key)?; let node_key = tls_util::load_private_key(&tls_config.node_key)?;
@ -116,7 +139,7 @@ pub async fn run_rpc_server(
config.set_single_cert([&node_certs[..], &ca_certs[..]].concat(), node_key)?; config.set_single_cert([&node_certs[..], &ca_certs[..]].concat(), node_key)?;
let tls_acceptor = Arc::new(TlsAcceptor::from(Arc::new(config))); let tls_acceptor = Arc::new(TlsAcceptor::from(Arc::new(config)));
let mut listener = TcpListener::bind(&bind_addr).await?; let mut listener = TcpListener::bind(&self.bind_addr).await?;
let incoming = listener.incoming().filter_map(|socket| async { let incoming = listener.incoming().filter_map(|socket| async {
match socket { match socket {
Ok(stream) => match tls_acceptor.clone().accept(stream).await { Ok(stream) => match tls_acceptor.clone().accept(stream).await {
@ -131,17 +154,17 @@ pub async fn run_rpc_server(
}); });
let incoming = hyper::server::accept::from_stream(incoming); let incoming = hyper::server::accept::from_stream(incoming);
let self_arc = self.clone();
let service = make_service_fn(|conn: &TlsStream<TcpStream>| { let service = make_service_fn(|conn: &TlsStream<TcpStream>| {
let client_addr = conn let client_addr = conn
.get_ref() .get_ref()
.0 .0
.peer_addr() .peer_addr()
.unwrap_or(([0, 0, 0, 0], 0).into()); .unwrap_or(([0, 0, 0, 0], 0).into());
let garage = garage.clone(); let self_arc = self_arc.clone();
async move { async move {
Ok::<_, Error>(service_fn(move |req: Request<Body>| { Ok::<_, Error>(service_fn(move |req: Request<Body>| {
let garage = garage.clone(); self_arc.clone().handler(req, client_addr).map_err(|e| {
handler(garage, req, client_addr).map_err(|e| {
eprintln!("RPC handler error: {}", e); eprintln!("RPC handler error: {}", e);
e e
}) })
@ -152,17 +175,17 @@ pub async fn run_rpc_server(
let server = Server::builder(incoming).serve(service); let server = Server::builder(incoming).serve(service);
let graceful = server.with_graceful_shutdown(shutdown_signal); let graceful = server.with_graceful_shutdown(shutdown_signal);
println!("RPC server listening on http://{}", bind_addr); println!("RPC server listening on http://{}", self.bind_addr);
graceful.await?; graceful.await?;
} else { } else {
let service = make_service_fn(|conn: &AddrStream| { let self_arc = self.clone();
let service = make_service_fn(move |conn: &AddrStream| {
let client_addr = conn.remote_addr(); let client_addr = conn.remote_addr();
let garage = garage.clone(); let self_arc = self_arc.clone();
async move { async move {
Ok::<_, Error>(service_fn(move |req: Request<Body>| { Ok::<_, Error>(service_fn(move |req: Request<Body>| {
let garage = garage.clone(); self_arc.clone().handler(req, client_addr).map_err(|e| {
handler(garage, req, client_addr).map_err(|e| {
eprintln!("RPC handler error: {}", e); eprintln!("RPC handler error: {}", e);
e e
}) })
@ -170,13 +193,14 @@ pub async fn run_rpc_server(
} }
}); });
let server = Server::bind(&bind_addr).serve(service); let server = Server::bind(&self.bind_addr).serve(service);
let graceful = server.with_graceful_shutdown(shutdown_signal); let graceful = server.with_graceful_shutdown(shutdown_signal);
println!("RPC server listening on http://{}", bind_addr); println!("RPC server listening on http://{}", self.bind_addr);
graceful.await?; graceful.await?;
} }
Ok(()) Ok(())
}
} }

View file

@ -1,4 +1,3 @@
use std::collections::HashMap;
use std::io::{Read, Write}; use std::io::{Read, Write};
use std::net::SocketAddr; use std::net::SocketAddr;
use std::path::PathBuf; use std::path::PathBuf;
@ -15,7 +14,7 @@ use crate::data::*;
use crate::error::Error; use crate::error::Error;
use crate::membership::System; use crate::membership::System;
use crate::proto::*; use crate::proto::*;
use crate::rpc_server; use crate::rpc_server::RpcServer;
use crate::table::*; use crate::table::*;
#[derive(Deserialize, Debug, Clone)] #[derive(Deserialize, Debug, Clone)]
@ -53,8 +52,6 @@ pub struct Garage {
pub system: Arc<System>, pub system: Arc<System>,
pub block_manager: Arc<BlockManager>, pub block_manager: Arc<BlockManager>,
pub table_rpc_handlers: HashMap<String, Box<dyn TableRpcHandler + Sync + Send>>,
pub object_table: Arc<Table<ObjectTable>>, pub object_table: Arc<Table<ObjectTable>>,
pub version_table: Arc<Table<VersionTable>>, pub version_table: Arc<Table<VersionTable>>,
pub block_ref_table: Arc<Table<BlockRefTable>>, pub block_ref_table: Arc<Table<BlockRefTable>>,
@ -66,12 +63,14 @@ impl Garage {
id: UUID, id: UUID,
db: sled::Db, db: sled::Db,
background: Arc<BackgroundRunner>, background: Arc<BackgroundRunner>,
rpc_server: &mut RpcServer,
) -> Arc<Self> { ) -> Arc<Self> {
println!("Initialize membership management system..."); println!("Initialize membership management system...");
let system = Arc::new(System::new(config.clone(), id, background.clone())); let system = System::new(config.clone(), id, background.clone(), rpc_server);
println!("Initialize block manager..."); println!("Initialize block manager...");
let block_manager = BlockManager::new(&db, config.data_dir.clone(), system.clone()); let block_manager =
BlockManager::new(&db, config.data_dir.clone(), system.clone(), rpc_server);
let data_rep_param = TableReplicationParams { let data_rep_param = TableReplicationParams {
replication_factor: system.config.data_replication_factor, replication_factor: system.config.data_replication_factor,
@ -97,6 +96,7 @@ impl Garage {
&db, &db,
"block_ref".to_string(), "block_ref".to_string(),
data_rep_param.clone(), data_rep_param.clone(),
rpc_server,
) )
.await; .await;
@ -110,6 +110,7 @@ impl Garage {
&db, &db,
"version".to_string(), "version".to_string(),
meta_rep_param.clone(), meta_rep_param.clone(),
rpc_server,
) )
.await; .await;
@ -123,35 +124,20 @@ impl Garage {
&db, &db,
"object".to_string(), "object".to_string(),
meta_rep_param.clone(), meta_rep_param.clone(),
rpc_server,
) )
.await; .await;
println!("Initialize Garage..."); println!("Initialize Garage...");
let mut garage = Self { let garage = Arc::new(Self {
db, db,
system: system.clone(), system: system.clone(),
block_manager, block_manager,
background, background,
table_rpc_handlers: HashMap::new(),
object_table, object_table,
version_table, version_table,
block_ref_table, block_ref_table,
}; });
garage.table_rpc_handlers.insert(
garage.object_table.name.clone(),
garage.object_table.clone().rpc_handler(),
);
garage.table_rpc_handlers.insert(
garage.version_table.name.clone(),
garage.version_table.clone().rpc_handler(),
);
garage.table_rpc_handlers.insert(
garage.block_ref_table.name.clone(),
garage.block_ref_table.clone().rpc_handler(),
);
let garage = Arc::new(garage);
println!("Start block manager background thread..."); println!("Start block manager background thread...");
garage.block_manager.garage.swap(Some(garage.clone())); garage.block_manager.garage.swap(Some(garage.clone()));
@ -232,20 +218,23 @@ pub async fn run_server(config_file: PathBuf) -> Result<(), Error> {
db_path.push("db"); db_path.push("db");
let db = sled::open(db_path).expect("Unable to open DB"); let db = sled::open(db_path).expect("Unable to open DB");
let (send_cancel, watch_cancel) = watch::channel(false); println!("Initialize RPC server...");
let rpc_bind_addr = ([0, 0, 0, 0, 0, 0, 0, 0], config.rpc_port).into();
let mut rpc_server = RpcServer::new(rpc_bind_addr, config.rpc_tls.clone());
println!("Initializing background runner..."); println!("Initializing background runner...");
let (send_cancel, watch_cancel) = watch::channel(false);
let background = BackgroundRunner::new(8, watch_cancel.clone()); let background = BackgroundRunner::new(8, watch_cancel.clone());
let garage = Garage::new(config, id, db, background.clone()).await; let garage = Garage::new(config, id, db, background.clone(), &mut rpc_server).await;
println!("Initializing RPC and API servers..."); println!("Initializing RPC and API servers...");
let rpc_server = rpc_server::run_rpc_server(garage.clone(), wait_from(watch_cancel.clone())); let run_rpc_server = Arc::new(rpc_server).run(wait_from(watch_cancel.clone()));
let api_server = api_server::run_api_server(garage.clone(), wait_from(watch_cancel.clone())); let api_server = api_server::run_api_server(garage.clone(), wait_from(watch_cancel.clone()));
futures::try_join!( futures::try_join!(
garage.system.clone().bootstrap().map(Ok), garage.system.clone().bootstrap().map(Ok),
rpc_server, run_rpc_server,
api_server, api_server,
background.run().map(Ok), background.run().map(Ok),
shutdown_signal(send_cancel), shutdown_signal(send_cancel),

View file

@ -11,14 +11,15 @@ use serde_bytes::ByteBuf;
use crate::data::*; use crate::data::*;
use crate::error::Error; use crate::error::Error;
use crate::membership::System; use crate::membership::System;
use crate::proto::*;
use crate::rpc_client::*; use crate::rpc_client::*;
use crate::rpc_server::*;
use crate::table_sync::*; use crate::table_sync::*;
pub struct Table<F: TableSchema> { pub struct Table<F: TableSchema> {
pub instance: F, pub instance: F,
pub name: String, pub name: String,
pub rpc_client: Arc<RpcClient<TableRPC<F>>>,
pub system: Arc<System>, pub system: Arc<System>,
pub store: sled::Tree, pub store: sled::Tree,
@ -35,24 +36,6 @@ pub struct TableReplicationParams {
pub timeout: Duration, pub timeout: Duration,
} }
#[async_trait]
pub trait TableRpcHandler {
async fn handle(&self, rpc: &[u8]) -> Result<Vec<u8>, Error>;
}
struct TableRpcHandlerAdapter<F: TableSchema> {
table: Arc<Table<F>>,
}
#[async_trait]
impl<F: TableSchema + 'static> TableRpcHandler for TableRpcHandlerAdapter<F> {
async fn handle(&self, rpc: &[u8]) -> Result<Vec<u8>, Error> {
let msg = rmp_serde::decode::from_read_ref::<_, TableRPC<F>>(rpc)?;
let rep = self.table.handle(msg).await?;
Ok(rmp_to_vec_all_named(&rep)?)
}
}
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
pub enum TableRPC<F: TableSchema> { pub enum TableRPC<F: TableSchema> {
Ok, Ok,
@ -67,6 +50,8 @@ pub enum TableRPC<F: TableSchema> {
SyncRPC(SyncRPC), SyncRPC(SyncRPC),
} }
impl<F: TableSchema> RpcMessage for TableRPC<F> {}
pub trait PartitionKey { pub trait PartitionKey {
fn hash(&self) -> Hash; fn hash(&self) -> Hash;
} }
@ -136,18 +121,27 @@ impl<F: TableSchema + 'static> Table<F> {
db: &sled::Db, db: &sled::Db,
name: String, name: String,
param: TableReplicationParams, param: TableReplicationParams,
rpc_server: &mut RpcServer,
) -> Arc<Self> { ) -> Arc<Self> {
let store = db.open_tree(&name).expect("Unable to open DB tree"); let store = db.open_tree(&name).expect("Unable to open DB tree");
let rpc_path = format!("table_{}", name);
let rpc_client = system.rpc_client::<TableRPC<F>>(&rpc_path);
let table = Arc::new(Self { let table = Arc::new(Self {
instance, instance,
name, name,
rpc_client,
system, system,
store, store,
param, param,
syncer: ArcSwapOption::from(None), syncer: ArcSwapOption::from(None),
}); });
table.clone().register_handler(rpc_server, rpc_path);
let syncer = TableSyncer::launch(table.clone()).await; let syncer = TableSyncer::launch(table.clone()).await;
table.syncer.swap(Some(syncer)); table.syncer.swap(Some(syncer));
table table
} }
@ -158,9 +152,10 @@ impl<F: TableSchema + 'static> Table<F> {
//eprintln!("insert who: {:?}", who); //eprintln!("insert who: {:?}", who);
let e_enc = Arc::new(ByteBuf::from(rmp_to_vec_all_named(e)?)); let e_enc = Arc::new(ByteBuf::from(rmp_to_vec_all_named(e)?));
let rpc = &TableRPC::<F>::Update(vec![e_enc]); let rpc = TableRPC::<F>::Update(vec![e_enc]);
self.rpc_try_call_many(&who[..], &rpc, self.param.write_quorum) self.rpc_client
.try_call_many(&who[..], rpc, self.param.write_quorum, self.param.timeout)
.await?; .await?;
Ok(()) Ok(())
} }
@ -183,10 +178,8 @@ impl<F: TableSchema + 'static> Table<F> {
let call_futures = call_list.drain().map(|(node, entries)| async move { let call_futures = call_list.drain().map(|(node, entries)| async move {
let rpc = TableRPC::<F>::Update(entries); let rpc = TableRPC::<F>::Update(entries);
let rpc_bytes = rmp_to_vec_all_named(&rpc)?;
let message = Message::TableRPC(self.name.to_string(), rpc_bytes);
let resp = rpc_call(self.system.clone(), &node, &message, self.param.timeout).await?; let resp = self.rpc_client.call(&node, rpc, self.param.timeout).await?;
Ok::<_, Error>((node, resp)) Ok::<_, Error>((node, resp))
}); });
let mut resps = call_futures.collect::<FuturesUnordered<_>>(); let mut resps = call_futures.collect::<FuturesUnordered<_>>();
@ -214,9 +207,10 @@ impl<F: TableSchema + 'static> Table<F> {
let who = ring.walk_ring(&hash, self.param.replication_factor); let who = ring.walk_ring(&hash, self.param.replication_factor);
//eprintln!("get who: {:?}", who); //eprintln!("get who: {:?}", who);
let rpc = &TableRPC::<F>::ReadEntry(partition_key.clone(), sort_key.clone()); let rpc = TableRPC::<F>::ReadEntry(partition_key.clone(), sort_key.clone());
let resps = self let resps = self
.rpc_try_call_many(&who[..], &rpc, self.param.read_quorum) .rpc_client
.try_call_many(&who[..], rpc, self.param.read_quorum, self.param.timeout)
.await?; .await?;
let mut ret = None; let mut ret = None;
@ -264,9 +258,10 @@ impl<F: TableSchema + 'static> Table<F> {
let who = ring.walk_ring(&hash, self.param.replication_factor); let who = ring.walk_ring(&hash, self.param.replication_factor);
let rpc = let rpc =
&TableRPC::<F>::ReadRange(partition_key.clone(), begin_sort_key.clone(), filter, limit); TableRPC::<F>::ReadRange(partition_key.clone(), begin_sort_key.clone(), filter, limit);
let resps = self let resps = self
.rpc_try_call_many(&who[..], &rpc, self.param.read_quorum) .rpc_client
.try_call_many(&who[..], rpc, self.param.read_quorum, self.param.timeout)
.await?; .await?;
let mut ret = BTreeMap::new(); let mut ret = BTreeMap::new();
@ -315,71 +310,24 @@ impl<F: TableSchema + 'static> Table<F> {
async fn repair_on_read(&self, who: &[UUID], what: F::E) -> Result<(), Error> { async fn repair_on_read(&self, who: &[UUID], what: F::E) -> Result<(), Error> {
let what_enc = Arc::new(ByteBuf::from(rmp_to_vec_all_named(&what)?)); let what_enc = Arc::new(ByteBuf::from(rmp_to_vec_all_named(&what)?));
self.rpc_try_call_many(&who[..], &TableRPC::<F>::Update(vec![what_enc]), who.len()) self.rpc_client
.try_call_many(
&who[..],
TableRPC::<F>::Update(vec![what_enc]),
who.len(),
self.param.timeout,
)
.await?; .await?;
Ok(()) Ok(())
} }
async fn rpc_try_call_many(
&self,
who: &[UUID],
rpc: &TableRPC<F>,
quorum: usize,
) -> Result<Vec<TableRPC<F>>, Error> {
//eprintln!("Table RPC to {:?}: {}", who, serde_json::to_string(&rpc)?);
let rpc_bytes = rmp_to_vec_all_named(rpc)?;
let rpc_msg = Message::TableRPC(self.name.to_string(), rpc_bytes);
let resps = rpc_try_call_many(
self.system.clone(),
who,
rpc_msg,
quorum,
self.param.timeout,
)
.await?;
let mut resps_vals = vec![];
for resp in resps {
if let Message::TableRPC(tbl, rep_by) = &resp {
if *tbl == self.name {
resps_vals.push(rmp_serde::decode::from_read_ref(&rep_by)?);
continue;
}
}
return Err(Error::Message(format!(
"Invalid reply to TableRPC: {:?}",
resp
)));
}
//eprintln!(
// "Table RPC responses: {}",
// serde_json::to_string(&resps_vals)?
//);
Ok(resps_vals)
}
pub async fn rpc_call(&self, who: &UUID, rpc: &TableRPC<F>) -> Result<TableRPC<F>, Error> {
let rpc_bytes = rmp_to_vec_all_named(rpc)?;
let rpc_msg = Message::TableRPC(self.name.to_string(), rpc_bytes);
let resp = rpc_call(self.system.clone(), who, &rpc_msg, self.param.timeout).await?;
if let Message::TableRPC(tbl, rep_by) = &resp {
if *tbl == self.name {
return Ok(rmp_serde::decode::from_read_ref(&rep_by)?);
}
}
Err(Error::Message(format!(
"Invalid reply to TableRPC: {:?}",
resp
)))
}
// =============== HANDLERS FOR RPC OPERATIONS (SERVER SIDE) ============== // =============== HANDLERS FOR RPC OPERATIONS (SERVER SIDE) ==============
pub fn rpc_handler(self: Arc<Self>) -> Box<dyn TableRpcHandler + Send + Sync> { fn register_handler(self: Arc<Self>, rpc_server: &mut RpcServer, path: String) {
Box::new(TableRpcHandlerAdapter::<F> { table: self }) rpc_server.add_handler::<TableRPC<F>, _, _>(path, move |msg, _addr| {
let self2 = self.clone();
async move { self2.handle(msg).await }
})
} }
async fn handle(self: &Arc<Self>, msg: TableRPC<F>) -> Result<TableRPC<F>, Error> { async fn handle(self: &Arc<Self>, msg: TableRPC<F>) -> Result<TableRPC<F>, Error> {

View file

@ -360,12 +360,14 @@ impl<F: TableSchema + 'static> TableSyncer<F> {
// If their root checksum has level > than us, use that as a reference // If their root checksum has level > than us, use that as a reference
let root_cks_resp = self let root_cks_resp = self
.table .table
.rpc_call( .rpc_client
.call(
&who, &who,
&TableRPC::<F>::SyncRPC(SyncRPC::GetRootChecksumRange( &TableRPC::<F>::SyncRPC(SyncRPC::GetRootChecksumRange(
partition.begin.clone(), partition.begin.clone(),
partition.end.clone(), partition.end.clone(),
)), )),
self.table.param.timeout,
) )
.await?; .await?;
if let TableRPC::<F>::SyncRPC(SyncRPC::RootChecksumRange(range)) = root_cks_resp { if let TableRPC::<F>::SyncRPC(SyncRPC::RootChecksumRange(range)) = root_cks_resp {
@ -392,9 +394,11 @@ impl<F: TableSchema + 'static> TableSyncer<F> {
let rpc_resp = self let rpc_resp = self
.table .table
.rpc_call( .rpc_client
.call(
&who, &who,
&TableRPC::<F>::SyncRPC(SyncRPC::Checksums(step, retain)), &TableRPC::<F>::SyncRPC(SyncRPC::Checksums(step, retain)),
self.table.param.timeout,
) )
.await?; .await?;
if let TableRPC::<F>::SyncRPC(SyncRPC::Difference(mut diff_ranges, diff_items)) = if let TableRPC::<F>::SyncRPC(SyncRPC::Difference(mut diff_ranges, diff_items)) =
@ -451,7 +455,12 @@ impl<F: TableSchema + 'static> TableSyncer<F> {
} }
let rpc_resp = self let rpc_resp = self
.table .table
.rpc_call(&who, &TableRPC::<F>::Update(values)) .rpc_client
.call(
&who,
&TableRPC::<F>::Update(values),
self.table.param.timeout,
)
.await?; .await?;
if let TableRPC::<F>::Ok = rpc_resp { if let TableRPC::<F>::Ok = rpc_resp {
Ok(()) Ok(())