From 01e90f874cb5fbdb2122459cf3f3d57e986b38c2 Mon Sep 17 00:00:00 2001 From: Valentyn Faychuk Date: Tue, 24 Dec 2024 17:46:29 +0000 Subject: [PATCH] connections stability fixes Signed-off-by: Valentyn Faychuk --- src/datastore.rs | 17 ++++++++++++----- src/grpc/client.rs | 31 +++++++++++++++++++------------ src/grpc/server.rs | 23 +++++++++++++++-------- src/main.rs | 4 ++-- 4 files changed, 48 insertions(+), 27 deletions(-) diff --git a/src/datastore.rs b/src/datastore.rs index 9626b19..c6162a0 100644 --- a/src/datastore.rs +++ b/src/datastore.rs @@ -94,14 +94,15 @@ pub struct State { my_ip: String, nodes: RwLock>, conns: RwLock>, + timeout: u64, } impl State { - pub fn new(my_ip: String) -> Self { + pub fn new(my_ip: String, timeout: u64) -> Self { let mut nodes = HashMap::new(); let my_info = NodeInfo::load(); nodes.insert(my_ip.clone(), my_info); - Self { my_ip, nodes: RwLock::new(nodes), conns: RwLock::new(HashSet::new()) } + Self { my_ip, timeout, nodes: RwLock::new(nodes), conns: RwLock::new(HashSet::new()) } } pub async fn add_conn(&self, ip: &str) { @@ -198,6 +199,10 @@ impl State { self.my_ip.clone() } + pub fn get_timeout(&self) -> Duration { + Duration::from_secs(self.timeout) + } + pub async fn get_my_info(&self) -> NodeInfo { let nodes = self.nodes.read().await; nodes.get(&self.my_ip).cloned().unwrap_or(NodeInfo::new_empty()) @@ -205,7 +210,7 @@ impl State { pub async fn get_connected_ips(&self) -> Vec { let conns = self.conns.read().await; - conns.iter().map(|n| n.clone()).collect() + conns.iter().cloned().collect() } // returns a random node that does not have an active connection @@ -239,13 +244,15 @@ impl State { is_update_new && !is_update_mine } - pub async fn remove_inactive_nodes(&self, max_age: u64) { + pub async fn remove_inactive_nodes(&self) { let mut nodes = self.nodes.write().await; // TODO: Check if it is possible to corrupt SGX system time + // TODO: Double check if we need to cleanup nodes let now = SystemTime::now(); + let max_age = 
self.timeout; nodes.retain(|_, n| { let age = now.duration_since(n.keepalive).unwrap_or(Duration::ZERO).as_secs(); - age <= max_age + age < max_age }); } } diff --git a/src/grpc/client.rs b/src/grpc/client.rs index 43c9ba0..df81563 100644 --- a/src/grpc/client.rs +++ b/src/grpc/client.rs @@ -104,19 +104,26 @@ impl ConnManager { } }); - let response = client.get_updates(rx_stream).await; - if let Err(e) = response { - println!("Error connecting to {remote_ip}: {e}"); - if e.to_string().contains("QuoteVerifyError") { - self.state.increase_net_attacks().await; + let to_stream = match client.get_updates(rx_stream).await { + Ok(response) => { + let stream = response.into_inner(); + stream.timeout(self.state.get_timeout()) } - return Err(e.into()); - } - let mut resp_stream = response.unwrap().into_inner(); + Err(e) => { + println!("Error connecting to {remote_ip}: {e}"); + if e.to_string().contains("QuoteVerifyError") { + self.state.increase_net_attacks().await; + } + return Err(e.into()); + } + }; - // Immediately send our info as a network update - let _ = self.tx.send((self.state.get_my_ip().await, self.state.get_my_info().await).into()); - while let Some(update) = resp_stream.message().await? 
{ + // TODO: Check if immediately sending our info as a network update to everybody works + //let _ = self.tx.send((self.state.get_my_ip().await, self.state.get_my_info().await).into()); + + tokio::pin!(to_stream); + let mut updates = to_stream.take_while(Result::is_ok); + while let Some(Ok(Ok(update))) = updates.next().await { // Update the entire network in case the information is new if self.state.process_node_update(update.clone().into()).await { // If process update returns true, the update must be forwarded @@ -126,7 +133,7 @@ impl ConnManager { } } - Ok(()) + Err(std::io::Error::new(std::io::ErrorKind::Other, "Updates interrupted").into()) } } diff --git a/src/grpc/server.rs b/src/grpc/server.rs index 6746f0c..9f60512 100644 --- a/src/grpc/server.rs +++ b/src/grpc/server.rs @@ -7,6 +7,7 @@ use detee_sgx::RaTlsConfig; use rustls::pki_types::CertificateDer; use std::{pin::Pin, sync::Arc}; use tokio::sync::broadcast::Sender; +use tokio::time::interval_at; use tokio_stream::{Stream, StreamExt}; use tonic::{Request, Response, Status, Streaming}; @@ -75,6 +76,7 @@ impl NodeServer { let tls_acceptor = tls_acceptor.clone(); let svc = svc.clone(); + state.declare_myself_public().await; let state = state.clone(); tokio::spawn(async move { let mut certificates = Vec::new(); @@ -106,6 +108,8 @@ impl NodeServer { })); let svc = ServiceBuilder::new().layer(extension_layer).service(svc); + let ip = addr.ip().to_string(); + state.add_conn(&ip).await; if let Err(e) = http .serve_connection( TokioIo::new(conn), @@ -117,6 +121,7 @@ impl NodeServer { { println!("Error serving connection: {}", e); } + state.delete_conn(&ip).await; }); } } @@ -151,18 +156,17 @@ impl Update for NodeServer { let my_ip = self.state.get_my_ip().await; let stream = async_stream::stream! 
{ - state.declare_myself_public().await; - state.add_conn(&remote_ip).await; - let known_nodes: Vec = state.get_nodes().await.into_iter().map(Into::into).collect(); for update in known_nodes { yield Ok(update); } - let error_status: Status; + let error_status: Status; // Gets initialized inside loop + let mut timeout = interval_at(tokio::time::Instant::now() + state.get_timeout(), state.get_timeout()); loop { tokio::select! { Some(msg) = inbound.next() => { + timeout.reset(); match msg { Ok(update) => { if update.ip == remote_ip { @@ -174,7 +178,7 @@ impl Update for NodeServer { } if state.process_node_update(update.clone().into()).await { - // if process update returns true, the update must be forwarded + // If process update returns true, the update must be forwarded if tx.send((remote_ip.clone(), update).into()).is_err() { println!("Tokio broadcast receivers had an issue consuming the channel"); }; @@ -188,18 +192,21 @@ impl Update for NodeServer { } Ok(update) = rx.recv() => { if update.sender_ip != remote_ip { - // don't bounce back the update we just received + // Don't bounce back the update we just received yield Ok(update.update); } - // disconnect client if too many connections are active + // TODO: check if disconnect client if too many connections are active if tx.receiver_count() > 9 { error_status = Status::internal("Already have too many clients. 
Connect to another server."); break; } } + _ = timeout.tick() => { // Fires only after a full timeout period without an inbound message + error_status = Status::internal(format!("Disconnecting after {}s timeout", state.get_timeout().as_secs())); + break; + } } } - state.delete_conn(&remote_ip).await; yield Err(error_status); }; diff --git a/src/main.rs b/src/main.rs index 82b616c..90b14f3 100644 --- a/src/main.rs +++ b/src/main.rs @@ -51,7 +51,7 @@ pub async fn heartbeat( loop { interval.tick().await; println!("Heartbeat..."); - state.remove_inactive_nodes(HEARTBEAT_INTERVAL * 3).await; + state.remove_inactive_nodes().await; let connected_ips = state.get_connected_ips().await; println!("Connected nodes ({}): {:?}", connected_ips.len(), connected_ips); let _ = tx.send((state.get_my_ip().await, state.get_my_info().await).into()); @@ -127,7 +127,7 @@ async fn main() { let my_ip = resolve_my_ipv4().await; // Guaranteed to be correct IPv4 println!("Starting on IP {}", my_ip); - let state = Arc::new(State::new(my_ip.clone())); + let state = Arc::new(State::new(my_ip.clone(), HEARTBEAT_INTERVAL * 3)); let sol_client = Arc::new(init_token(state.clone(), ra_cfg.clone()).await); let mut tasks = JoinSet::new();