Compare commits

..

No commits in common. "4a6e8c4c73dcc9dd99e875438f6c7a2d77822bda" and "eceafd9de8fe89fbdd49bcfed0553c22a4e5869d" have entirely different histories.

4 changed files with 28 additions and 62 deletions

@ -94,15 +94,14 @@ pub struct State {
my_ip: String,
nodes: RwLock<HashMap<IP, NodeInfo>>,
conns: RwLock<HashSet<IP>>,
timeout: u64,
}
impl State {
pub fn new(my_ip: String, timeout: u64) -> Self {
pub fn new(my_ip: String) -> Self {
let mut nodes = HashMap::new();
let my_info = NodeInfo::load();
nodes.insert(my_ip.clone(), my_info);
Self { my_ip, timeout, nodes: RwLock::new(nodes), conns: RwLock::new(HashSet::new()) }
Self { my_ip, nodes: RwLock::new(nodes), conns: RwLock::new(HashSet::new()) }
}
pub async fn add_conn(&self, ip: &str) {
@ -199,10 +198,6 @@ impl State {
self.my_ip.clone()
}
pub fn get_timeout(&self) -> Duration {
Duration::from_secs(self.timeout)
}
pub async fn get_my_info(&self) -> NodeInfo {
let nodes = self.nodes.read().await;
nodes.get(&self.my_ip).cloned().unwrap_or(NodeInfo::new_empty())
@ -210,7 +205,7 @@ impl State {
pub async fn get_connected_ips(&self) -> Vec<String> {
let conns = self.conns.read().await;
conns.iter().cloned().collect()
conns.iter().map(|n| n.clone()).collect()
}
// returns a random node that does not have an active connection
@ -219,10 +214,7 @@ impl State {
let conn_ips = self.get_connected_ips().await;
let nodes = self.nodes.read().await;
if nodes.is_empty() {
return None;
}
let skip = OsRng.next_u64() as usize % nodes.len();
let skip = OsRng.next_u64().try_into().unwrap_or(0) % nodes.len();
nodes
.keys()
.map(|ip| ip.to_string())
@ -231,15 +223,6 @@ impl State {
.nth(skip)
}
pub async fn update_keepalive(&self) {
let mut nodes = self.nodes.write().await;
if let Some(my_info) = nodes.get_mut(&self.my_ip) {
let mut updated_info = my_info.clone();
updated_info.keepalive = SystemTime::now();
let _ = nodes.insert(self.my_ip.clone(), updated_info);
}
}
/// This returns true if the update should be further forwarded
/// For example, we never forward our own updates that came back
pub async fn process_node_update(&self, (node_ip, node_info): (String, NodeInfo)) -> bool {
@ -256,15 +239,13 @@ impl State {
is_update_new && !is_update_mine
}
pub async fn remove_inactive_nodes(&self) {
pub async fn remove_inactive_nodes(&self, max_age: u64) {
let mut nodes = self.nodes.write().await;
// TODO: Check if it is possible to corrupt SGX system time
// TODO: Double check if we need to cleanup nodes
let now = SystemTime::now();
let max_age = self.timeout;
nodes.retain(|_, n| {
let age = now.duration_since(n.keepalive).unwrap_or(Duration::ZERO).as_secs();
age < max_age
age <= max_age
});
}
}

@ -104,26 +104,19 @@ impl ConnManager {
}
});
let to_stream = match client.get_updates(rx_stream).await {
Ok(response) => {
let stream = response.into_inner();
stream.timeout(self.state.get_timeout())
}
Err(e) => {
let response = client.get_updates(rx_stream).await;
if let Err(e) = response {
println!("Error connecting to {remote_ip}: {e}");
if e.to_string().contains("QuoteVerifyError") {
self.state.increase_net_attacks().await;
}
return Err(e.into());
}
};
let mut resp_stream = response.unwrap().into_inner();
// TODO: Check if immediately sending our info as a network update to everybody works
//let _ = self.tx.send((self.state.get_my_ip().await, self.state.get_my_info().await).into());
tokio::pin!(to_stream);
let mut updates = to_stream.take_while(Result::is_ok);
while let Some(Ok(Ok(update))) = updates.next().await {
// Immediately send our info as a network update
let _ = self.tx.send((self.state.get_my_ip().await, self.state.get_my_info().await).into());
while let Some(update) = resp_stream.message().await? {
// Update the entire network in case the information is new
if self.state.process_node_update(update.clone().into()).await {
// If process update returns true, the update must be forwarded
@ -133,7 +126,7 @@ impl ConnManager {
}
}
Err(std::io::Error::new(std::io::ErrorKind::Other, "Updates interrupted").into())
Ok(())
}
}

@ -7,7 +7,6 @@ use detee_sgx::RaTlsConfig;
use rustls::pki_types::CertificateDer;
use std::{pin::Pin, sync::Arc};
use tokio::sync::broadcast::Sender;
use tokio::time::interval;
use tokio_stream::{Stream, StreamExt};
use tonic::{Request, Response, Status, Streaming};
@ -76,7 +75,6 @@ impl NodeServer {
let tls_acceptor = tls_acceptor.clone();
let svc = svc.clone();
state.declare_myself_public().await;
let state = state.clone();
tokio::spawn(async move {
let mut certificates = Vec::new();
@ -108,8 +106,6 @@ impl NodeServer {
}));
let svc = ServiceBuilder::new().layer(extension_layer).service(svc);
let ip = addr.ip().to_string();
state.add_conn(&ip).await;
if let Err(e) = http
.serve_connection(
TokioIo::new(conn),
@ -121,7 +117,6 @@ impl NodeServer {
{
println!("Error serving connection: {}", e);
}
state.delete_conn(&ip).await;
});
}
}
@ -156,17 +151,18 @@ impl Update for NodeServer {
let my_ip = self.state.get_my_ip().await;
let stream = async_stream::stream! {
state.declare_myself_public().await;
state.add_conn(&remote_ip).await;
let known_nodes: Vec<NodeUpdate> = state.get_nodes().await.into_iter().map(Into::into).collect();
for update in known_nodes {
yield Ok(update);
}
let error_status: Status; // Gets initialized inside loop
let mut timeout = interval(state.get_timeout());
let error_status: Status;
loop {
tokio::select! {
Some(msg) = inbound.next() => {
timeout = interval(state.get_timeout());
match msg {
Ok(update) => {
if update.ip == remote_ip {
@ -178,7 +174,7 @@ impl Update for NodeServer {
}
if state.process_node_update(update.clone().into()).await {
// If process update returns true, the update must be forwarded
// if process update returns true, the update must be forwarded
if tx.send((remote_ip.clone(), update).into()).is_err() {
println!("Tokio broadcast receivers had an issue consuming the channel");
};
@ -192,21 +188,18 @@ impl Update for NodeServer {
}
Ok(update) = rx.recv() => {
if update.sender_ip != remote_ip {
// Don't bounce back the update we just received
// don't bounce back the update we just received
yield Ok(update.update);
}
// TODO: check if disconnect client if too many connections are active
// disconnect client if too many connections are active
if tx.receiver_count() > 9 {
error_status = Status::internal("Already have too many clients. Connect to another server.");
break;
}
}
_ = timeout.tick() => {
error_status = Status::internal(format!("Disconnecting after {}s timeout", state.get_timeout().as_secs()));
break;
}
}
}
state.delete_conn(&remote_ip).await;
yield Err(error_status);
};

@ -51,8 +51,7 @@ pub async fn heartbeat(
loop {
interval.tick().await;
println!("Heartbeat...");
state.update_keepalive().await;
state.remove_inactive_nodes().await;
state.remove_inactive_nodes(HEARTBEAT_INTERVAL * 3).await;
let connected_ips = state.get_connected_ips().await;
println!("Connected nodes ({}): {:?}", connected_ips.len(), connected_ips);
let _ = tx.send((state.get_my_ip().await, state.get_my_info().await).into());
@ -128,7 +127,7 @@ async fn main() {
let my_ip = resolve_my_ipv4().await; // Guaranteed to be correct IPv4
println!("Starting on IP {}", my_ip);
let state = Arc::new(State::new(my_ip.clone(), HEARTBEAT_INTERVAL * 3));
let state = Arc::new(State::new(my_ip.clone()));
let sol_client = Arc::new(init_token(state.clone(), ra_cfg.clone()).await);
let mut tasks = JoinSet::new();