Compare commits

..

No commits in common. "4a6e8c4c73dcc9dd99e875438f6c7a2d77822bda" and "eceafd9de8fe89fbdd49bcfed0553c22a4e5869d" have entirely different histories.

4 changed files with 28 additions and 62 deletions

@ -94,15 +94,14 @@ pub struct State {
my_ip: String, my_ip: String,
nodes: RwLock<HashMap<IP, NodeInfo>>, nodes: RwLock<HashMap<IP, NodeInfo>>,
conns: RwLock<HashSet<IP>>, conns: RwLock<HashSet<IP>>,
timeout: u64,
} }
impl State { impl State {
pub fn new(my_ip: String, timeout: u64) -> Self { pub fn new(my_ip: String) -> Self {
let mut nodes = HashMap::new(); let mut nodes = HashMap::new();
let my_info = NodeInfo::load(); let my_info = NodeInfo::load();
nodes.insert(my_ip.clone(), my_info); nodes.insert(my_ip.clone(), my_info);
Self { my_ip, timeout, nodes: RwLock::new(nodes), conns: RwLock::new(HashSet::new()) } Self { my_ip, nodes: RwLock::new(nodes), conns: RwLock::new(HashSet::new()) }
} }
pub async fn add_conn(&self, ip: &str) { pub async fn add_conn(&self, ip: &str) {
@ -199,10 +198,6 @@ impl State {
self.my_ip.clone() self.my_ip.clone()
} }
pub fn get_timeout(&self) -> Duration {
Duration::from_secs(self.timeout)
}
pub async fn get_my_info(&self) -> NodeInfo { pub async fn get_my_info(&self) -> NodeInfo {
let nodes = self.nodes.read().await; let nodes = self.nodes.read().await;
nodes.get(&self.my_ip).cloned().unwrap_or(NodeInfo::new_empty()) nodes.get(&self.my_ip).cloned().unwrap_or(NodeInfo::new_empty())
@ -210,7 +205,7 @@ impl State {
pub async fn get_connected_ips(&self) -> Vec<String> { pub async fn get_connected_ips(&self) -> Vec<String> {
let conns = self.conns.read().await; let conns = self.conns.read().await;
conns.iter().cloned().collect() conns.iter().map(|n| n.clone()).collect()
} }
// returns a random node that does not have an active connection // returns a random node that does not have an active connection
@ -219,10 +214,7 @@ impl State {
let conn_ips = self.get_connected_ips().await; let conn_ips = self.get_connected_ips().await;
let nodes = self.nodes.read().await; let nodes = self.nodes.read().await;
if nodes.is_empty() { let skip = OsRng.next_u64().try_into().unwrap_or(0) % nodes.len();
return None;
}
let skip = OsRng.next_u64() as usize % nodes.len();
nodes nodes
.keys() .keys()
.map(|ip| ip.to_string()) .map(|ip| ip.to_string())
@ -231,15 +223,6 @@ impl State {
.nth(skip) .nth(skip)
} }
pub async fn update_keepalive(&self) {
let mut nodes = self.nodes.write().await;
if let Some(my_info) = nodes.get_mut(&self.my_ip) {
let mut updated_info = my_info.clone();
updated_info.keepalive = SystemTime::now();
let _ = nodes.insert(self.my_ip.clone(), updated_info);
}
}
/// This returns true if the update should be further forwarded /// This returns true if the update should be further forwarded
/// For example, we never forward our own updates that came back /// For example, we never forward our own updates that came back
pub async fn process_node_update(&self, (node_ip, node_info): (String, NodeInfo)) -> bool { pub async fn process_node_update(&self, (node_ip, node_info): (String, NodeInfo)) -> bool {
@ -256,15 +239,13 @@ impl State {
is_update_new && !is_update_mine is_update_new && !is_update_mine
} }
pub async fn remove_inactive_nodes(&self) { pub async fn remove_inactive_nodes(&self, max_age: u64) {
let mut nodes = self.nodes.write().await; let mut nodes = self.nodes.write().await;
// TODO: Check if it is possible to corrupt SGX system time // TODO: Check if it is possible to corrupt SGX system time
// TODO: Double check if we need to cleanup nodes
let now = SystemTime::now(); let now = SystemTime::now();
let max_age = self.timeout;
nodes.retain(|_, n| { nodes.retain(|_, n| {
let age = now.duration_since(n.keepalive).unwrap_or(Duration::ZERO).as_secs(); let age = now.duration_since(n.keepalive).unwrap_or(Duration::ZERO).as_secs();
age < max_age age <= max_age
}); });
} }
} }

@ -104,26 +104,19 @@ impl ConnManager {
} }
}); });
let to_stream = match client.get_updates(rx_stream).await { let response = client.get_updates(rx_stream).await;
Ok(response) => { if let Err(e) = response {
let stream = response.into_inner();
stream.timeout(self.state.get_timeout())
}
Err(e) => {
println!("Error connecting to {remote_ip}: {e}"); println!("Error connecting to {remote_ip}: {e}");
if e.to_string().contains("QuoteVerifyError") { if e.to_string().contains("QuoteVerifyError") {
self.state.increase_net_attacks().await; self.state.increase_net_attacks().await;
} }
return Err(e.into()); return Err(e.into());
} }
}; let mut resp_stream = response.unwrap().into_inner();
// TODO: Check if immediately sending our info as a network update to everybody works // Immediately send our info as a network update
//let _ = self.tx.send((self.state.get_my_ip().await, self.state.get_my_info().await).into()); let _ = self.tx.send((self.state.get_my_ip().await, self.state.get_my_info().await).into());
while let Some(update) = resp_stream.message().await? {
tokio::pin!(to_stream);
let mut updates = to_stream.take_while(Result::is_ok);
while let Some(Ok(Ok(update))) = updates.next().await {
// Update the entire network in case the information is new // Update the entire network in case the information is new
if self.state.process_node_update(update.clone().into()).await { if self.state.process_node_update(update.clone().into()).await {
// If process update returns true, the update must be forwarded // If process update returns true, the update must be forwarded
@ -133,7 +126,7 @@ impl ConnManager {
} }
} }
Err(std::io::Error::new(std::io::ErrorKind::Other, "Updates interrupted").into()) Ok(())
} }
} }

@ -7,7 +7,6 @@ use detee_sgx::RaTlsConfig;
use rustls::pki_types::CertificateDer; use rustls::pki_types::CertificateDer;
use std::{pin::Pin, sync::Arc}; use std::{pin::Pin, sync::Arc};
use tokio::sync::broadcast::Sender; use tokio::sync::broadcast::Sender;
use tokio::time::interval;
use tokio_stream::{Stream, StreamExt}; use tokio_stream::{Stream, StreamExt};
use tonic::{Request, Response, Status, Streaming}; use tonic::{Request, Response, Status, Streaming};
@ -76,7 +75,6 @@ impl NodeServer {
let tls_acceptor = tls_acceptor.clone(); let tls_acceptor = tls_acceptor.clone();
let svc = svc.clone(); let svc = svc.clone();
state.declare_myself_public().await;
let state = state.clone(); let state = state.clone();
tokio::spawn(async move { tokio::spawn(async move {
let mut certificates = Vec::new(); let mut certificates = Vec::new();
@ -108,8 +106,6 @@ impl NodeServer {
})); }));
let svc = ServiceBuilder::new().layer(extension_layer).service(svc); let svc = ServiceBuilder::new().layer(extension_layer).service(svc);
let ip = addr.ip().to_string();
state.add_conn(&ip).await;
if let Err(e) = http if let Err(e) = http
.serve_connection( .serve_connection(
TokioIo::new(conn), TokioIo::new(conn),
@ -121,7 +117,6 @@ impl NodeServer {
{ {
println!("Error serving connection: {}", e); println!("Error serving connection: {}", e);
} }
state.delete_conn(&ip).await;
}); });
} }
} }
@ -156,17 +151,18 @@ impl Update for NodeServer {
let my_ip = self.state.get_my_ip().await; let my_ip = self.state.get_my_ip().await;
let stream = async_stream::stream! { let stream = async_stream::stream! {
state.declare_myself_public().await;
state.add_conn(&remote_ip).await;
let known_nodes: Vec<NodeUpdate> = state.get_nodes().await.into_iter().map(Into::into).collect(); let known_nodes: Vec<NodeUpdate> = state.get_nodes().await.into_iter().map(Into::into).collect();
for update in known_nodes { for update in known_nodes {
yield Ok(update); yield Ok(update);
} }
let error_status: Status; // Gets initialized inside loop let error_status: Status;
let mut timeout = interval(state.get_timeout());
loop { loop {
tokio::select! { tokio::select! {
Some(msg) = inbound.next() => { Some(msg) = inbound.next() => {
timeout = interval(state.get_timeout());
match msg { match msg {
Ok(update) => { Ok(update) => {
if update.ip == remote_ip { if update.ip == remote_ip {
@ -178,7 +174,7 @@ impl Update for NodeServer {
} }
if state.process_node_update(update.clone().into()).await { if state.process_node_update(update.clone().into()).await {
// If process update returns true, the update must be forwarded // if process update returns true, the update must be forwarded
if tx.send((remote_ip.clone(), update).into()).is_err() { if tx.send((remote_ip.clone(), update).into()).is_err() {
println!("Tokio broadcast receivers had an issue consuming the channel"); println!("Tokio broadcast receivers had an issue consuming the channel");
}; };
@ -192,21 +188,18 @@ impl Update for NodeServer {
} }
Ok(update) = rx.recv() => { Ok(update) = rx.recv() => {
if update.sender_ip != remote_ip { if update.sender_ip != remote_ip {
// Don't bounce back the update we just received // don't bounce back the update we just received
yield Ok(update.update); yield Ok(update.update);
} }
// TODO: check if disconnect client if too many connections are active // disconnect client if too many connections are active
if tx.receiver_count() > 9 { if tx.receiver_count() > 9 {
error_status = Status::internal("Already have too many clients. Connect to another server."); error_status = Status::internal("Already have too many clients. Connect to another server.");
break; break;
} }
} }
_ = timeout.tick() => {
error_status = Status::internal(format!("Disconnecting after {}s timeout", state.get_timeout().as_secs()));
break;
}
} }
} }
state.delete_conn(&remote_ip).await;
yield Err(error_status); yield Err(error_status);
}; };

@ -51,8 +51,7 @@ pub async fn heartbeat(
loop { loop {
interval.tick().await; interval.tick().await;
println!("Heartbeat..."); println!("Heartbeat...");
state.update_keepalive().await; state.remove_inactive_nodes(HEARTBEAT_INTERVAL * 3).await;
state.remove_inactive_nodes().await;
let connected_ips = state.get_connected_ips().await; let connected_ips = state.get_connected_ips().await;
println!("Connected nodes ({}): {:?}", connected_ips.len(), connected_ips); println!("Connected nodes ({}): {:?}", connected_ips.len(), connected_ips);
let _ = tx.send((state.get_my_ip().await, state.get_my_info().await).into()); let _ = tx.send((state.get_my_ip().await, state.get_my_info().await).into());
@ -128,7 +127,7 @@ async fn main() {
let my_ip = resolve_my_ipv4().await; // Guaranteed to be correct IPv4 let my_ip = resolve_my_ipv4().await; // Guaranteed to be correct IPv4
println!("Starting on IP {}", my_ip); println!("Starting on IP {}", my_ip);
let state = Arc::new(State::new(my_ip.clone(), HEARTBEAT_INTERVAL * 3)); let state = Arc::new(State::new(my_ip.clone()));
let sol_client = Arc::new(init_token(state.clone(), ra_cfg.clone()).await); let sol_client = Arc::new(init_token(state.clone(), ra_cfg.clone()).await);
let mut tasks = JoinSet::new(); let mut tasks = JoinSet::new();