connections stability fixes

Signed-off-by: Valentyn Faychuk <valy@detee.ltd>
This commit is contained in:
Valentyn Faychuk 2024-12-24 17:46:29 +00:00
parent eceafd9de8
commit 01e90f874c
Signed by: valy
GPG Key ID: F1AB995E20FEADC5
4 changed files with 48 additions and 27 deletions

@ -94,14 +94,15 @@ pub struct State {
my_ip: String,
nodes: RwLock<HashMap<IP, NodeInfo>>,
conns: RwLock<HashSet<IP>>,
timeout: u64,
}
impl State {
pub fn new(my_ip: String) -> Self {
pub fn new(my_ip: String, timeout: u64) -> Self {
let mut nodes = HashMap::new();
let my_info = NodeInfo::load();
nodes.insert(my_ip.clone(), my_info);
Self { my_ip, nodes: RwLock::new(nodes), conns: RwLock::new(HashSet::new()) }
Self { my_ip, timeout, nodes: RwLock::new(nodes), conns: RwLock::new(HashSet::new()) }
}
pub async fn add_conn(&self, ip: &str) {
@ -198,6 +199,10 @@ impl State {
self.my_ip.clone()
}
pub fn get_timeout(&self) -> Duration {
Duration::from_secs(self.timeout)
}
pub async fn get_my_info(&self) -> NodeInfo {
let nodes = self.nodes.read().await;
nodes.get(&self.my_ip).cloned().unwrap_or(NodeInfo::new_empty())
@ -205,7 +210,7 @@ impl State {
pub async fn get_connected_ips(&self) -> Vec<String> {
let conns = self.conns.read().await;
conns.iter().map(|n| n.clone()).collect()
conns.iter().cloned().collect()
}
// returns a random node that does not have an active connection
@ -239,13 +244,15 @@ impl State {
is_update_new && !is_update_mine
}
pub async fn remove_inactive_nodes(&self, max_age: u64) {
pub async fn remove_inactive_nodes(&self) {
let mut nodes = self.nodes.write().await;
// TODO: Check if it is possible to corrupt SGX system time
// TODO: Double check if we need to cleanup nodes
let now = SystemTime::now();
let max_age = self.timeout;
nodes.retain(|_, n| {
let age = now.duration_since(n.keepalive).unwrap_or(Duration::ZERO).as_secs();
age <= max_age
age < max_age
});
}
}

@ -104,19 +104,26 @@ impl ConnManager {
}
});
let response = client.get_updates(rx_stream).await;
if let Err(e) = response {
println!("Error connecting to {remote_ip}: {e}");
if e.to_string().contains("QuoteVerifyError") {
self.state.increase_net_attacks().await;
let to_stream = match client.get_updates(rx_stream).await {
Ok(response) => {
let stream = response.into_inner();
stream.timeout(self.state.get_timeout())
}
return Err(e.into());
}
let mut resp_stream = response.unwrap().into_inner();
Err(e) => {
println!("Error connecting to {remote_ip}: {e}");
if e.to_string().contains("QuoteVerifyError") {
self.state.increase_net_attacks().await;
}
return Err(e.into());
}
};
// Immediately send our info as a network update
let _ = self.tx.send((self.state.get_my_ip().await, self.state.get_my_info().await).into());
while let Some(update) = resp_stream.message().await? {
// TODO: Check if immediately sending our info as a network update to everybody works
//let _ = self.tx.send((self.state.get_my_ip().await, self.state.get_my_info().await).into());
tokio::pin!(to_stream);
let mut updates = to_stream.take_while(Result::is_ok);
while let Some(Ok(Ok(update))) = updates.next().await {
// Update the entire network in case the information is new
if self.state.process_node_update(update.clone().into()).await {
// If process update returns true, the update must be forwarded
@ -126,7 +133,7 @@ impl ConnManager {
}
}
Ok(())
Err(std::io::Error::new(std::io::ErrorKind::Other, "Updates interrupted").into())
}
}

@ -7,6 +7,7 @@ use detee_sgx::RaTlsConfig;
use rustls::pki_types::CertificateDer;
use std::{pin::Pin, sync::Arc};
use tokio::sync::broadcast::Sender;
use tokio::time::interval;
use tokio_stream::{Stream, StreamExt};
use tonic::{Request, Response, Status, Streaming};
@ -75,6 +76,7 @@ impl NodeServer {
let tls_acceptor = tls_acceptor.clone();
let svc = svc.clone();
state.declare_myself_public().await;
let state = state.clone();
tokio::spawn(async move {
let mut certificates = Vec::new();
@ -106,6 +108,8 @@ impl NodeServer {
}));
let svc = ServiceBuilder::new().layer(extension_layer).service(svc);
let ip = addr.ip().to_string();
state.add_conn(&ip).await;
if let Err(e) = http
.serve_connection(
TokioIo::new(conn),
@ -117,6 +121,7 @@ impl NodeServer {
{
println!("Error serving connection: {}", e);
}
state.delete_conn(&ip).await;
});
}
}
@ -151,18 +156,17 @@ impl Update for NodeServer {
let my_ip = self.state.get_my_ip().await;
let stream = async_stream::stream! {
state.declare_myself_public().await;
state.add_conn(&remote_ip).await;
let known_nodes: Vec<NodeUpdate> = state.get_nodes().await.into_iter().map(Into::into).collect();
for update in known_nodes {
yield Ok(update);
}
let error_status: Status;
let error_status: Status; // Gets initialized inside loop
let mut timeout = interval(state.get_timeout());
loop {
tokio::select! {
Some(msg) = inbound.next() => {
timeout = interval(state.get_timeout());
match msg {
Ok(update) => {
if update.ip == remote_ip {
@ -174,7 +178,7 @@ impl Update for NodeServer {
}
if state.process_node_update(update.clone().into()).await {
// if process update returns true, the update must be forwarded
// If process update returns true, the update must be forwarded
if tx.send((remote_ip.clone(), update).into()).is_err() {
println!("Tokio broadcast receivers had an issue consuming the channel");
};
@ -188,18 +192,21 @@ impl Update for NodeServer {
}
Ok(update) = rx.recv() => {
if update.sender_ip != remote_ip {
// don't bounce back the update we just received
// Don't bounce back the update we just received
yield Ok(update.update);
}
// disconnect client if too many connections are active
// TODO: check if disconnect client if too many connections are active
if tx.receiver_count() > 9 {
error_status = Status::internal("Already have too many clients. Connect to another server.");
break;
}
}
_ = timeout.tick() => {
error_status = Status::internal(format!("Disconnecting after {}s timeout", state.get_timeout().as_secs()));
break;
}
}
}
state.delete_conn(&remote_ip).await;
yield Err(error_status);
};

@ -51,7 +51,7 @@ pub async fn heartbeat(
loop {
interval.tick().await;
println!("Heartbeat...");
state.remove_inactive_nodes(HEARTBEAT_INTERVAL * 3).await;
state.remove_inactive_nodes().await;
let connected_ips = state.get_connected_ips().await;
println!("Connected nodes ({}): {:?}", connected_ips.len(), connected_ips);
let _ = tx.send((state.get_my_ip().await, state.get_my_info().await).into());
@ -127,7 +127,7 @@ async fn main() {
let my_ip = resolve_my_ipv4().await; // Guaranteed to be correct IPv4
println!("Starting on IP {}", my_ip);
let state = Arc::new(State::new(my_ip.clone()));
let state = Arc::new(State::new(my_ip.clone(), HEARTBEAT_INTERVAL * 3));
let sol_client = Arc::new(init_token(state.clone(), ra_cfg.clone()).await);
let mut tasks = JoinSet::new();