connections stability fixes
Signed-off-by: Valentyn Faychuk <valy@detee.ltd>
This commit is contained in:
parent
eceafd9de8
commit
01e90f874c
@ -94,14 +94,15 @@ pub struct State {
|
||||
my_ip: String,
|
||||
nodes: RwLock<HashMap<IP, NodeInfo>>,
|
||||
conns: RwLock<HashSet<IP>>,
|
||||
timeout: u64,
|
||||
}
|
||||
|
||||
impl State {
|
||||
pub fn new(my_ip: String) -> Self {
|
||||
pub fn new(my_ip: String, timeout: u64) -> Self {
|
||||
let mut nodes = HashMap::new();
|
||||
let my_info = NodeInfo::load();
|
||||
nodes.insert(my_ip.clone(), my_info);
|
||||
Self { my_ip, nodes: RwLock::new(nodes), conns: RwLock::new(HashSet::new()) }
|
||||
Self { my_ip, timeout, nodes: RwLock::new(nodes), conns: RwLock::new(HashSet::new()) }
|
||||
}
|
||||
|
||||
pub async fn add_conn(&self, ip: &str) {
|
||||
@ -198,6 +199,10 @@ impl State {
|
||||
self.my_ip.clone()
|
||||
}
|
||||
|
||||
pub fn get_timeout(&self) -> Duration {
|
||||
Duration::from_secs(self.timeout)
|
||||
}
|
||||
|
||||
pub async fn get_my_info(&self) -> NodeInfo {
|
||||
let nodes = self.nodes.read().await;
|
||||
nodes.get(&self.my_ip).cloned().unwrap_or(NodeInfo::new_empty())
|
||||
@ -205,7 +210,7 @@ impl State {
|
||||
|
||||
pub async fn get_connected_ips(&self) -> Vec<String> {
|
||||
let conns = self.conns.read().await;
|
||||
conns.iter().map(|n| n.clone()).collect()
|
||||
conns.iter().cloned().collect()
|
||||
}
|
||||
|
||||
// returns a random node that does not have an active connection
|
||||
@ -239,13 +244,15 @@ impl State {
|
||||
is_update_new && !is_update_mine
|
||||
}
|
||||
|
||||
pub async fn remove_inactive_nodes(&self, max_age: u64) {
|
||||
pub async fn remove_inactive_nodes(&self) {
|
||||
let mut nodes = self.nodes.write().await;
|
||||
// TODO: Check if it is possible to corrupt SGX system time
|
||||
// TODO: Double check if we need to clean up nodes
|
||||
let now = SystemTime::now();
|
||||
let max_age = self.timeout;
|
||||
nodes.retain(|_, n| {
|
||||
let age = now.duration_since(n.keepalive).unwrap_or(Duration::ZERO).as_secs();
|
||||
age <= max_age
|
||||
age < max_age
|
||||
});
|
||||
}
|
||||
}
|
||||
|
@ -104,19 +104,26 @@ impl ConnManager {
|
||||
}
|
||||
});
|
||||
|
||||
let response = client.get_updates(rx_stream).await;
|
||||
if let Err(e) = response {
|
||||
println!("Error connecting to {remote_ip}: {e}");
|
||||
if e.to_string().contains("QuoteVerifyError") {
|
||||
self.state.increase_net_attacks().await;
|
||||
let to_stream = match client.get_updates(rx_stream).await {
|
||||
Ok(response) => {
|
||||
let stream = response.into_inner();
|
||||
stream.timeout(self.state.get_timeout())
|
||||
}
|
||||
return Err(e.into());
|
||||
}
|
||||
let mut resp_stream = response.unwrap().into_inner();
|
||||
Err(e) => {
|
||||
println!("Error connecting to {remote_ip}: {e}");
|
||||
if e.to_string().contains("QuoteVerifyError") {
|
||||
self.state.increase_net_attacks().await;
|
||||
}
|
||||
return Err(e.into());
|
||||
}
|
||||
};
|
||||
|
||||
// Immediately send our info as a network update
|
||||
let _ = self.tx.send((self.state.get_my_ip().await, self.state.get_my_info().await).into());
|
||||
while let Some(update) = resp_stream.message().await? {
|
||||
// TODO: Check if immediately sending our info as a network update to everybody works
|
||||
//let _ = self.tx.send((self.state.get_my_ip().await, self.state.get_my_info().await).into());
|
||||
|
||||
tokio::pin!(to_stream);
|
||||
let mut updates = to_stream.take_while(Result::is_ok);
|
||||
while let Some(Ok(Ok(update))) = updates.next().await {
|
||||
// Update the entire network in case the information is new
|
||||
if self.state.process_node_update(update.clone().into()).await {
|
||||
// If process update returns true, the update must be forwarded
|
||||
@ -126,7 +133,7 @@ impl ConnManager {
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
Err(std::io::Error::new(std::io::ErrorKind::Other, "Updates interrupted").into())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -7,6 +7,7 @@ use detee_sgx::RaTlsConfig;
|
||||
use rustls::pki_types::CertificateDer;
|
||||
use std::{pin::Pin, sync::Arc};
|
||||
use tokio::sync::broadcast::Sender;
|
||||
use tokio::time::interval;
|
||||
use tokio_stream::{Stream, StreamExt};
|
||||
use tonic::{Request, Response, Status, Streaming};
|
||||
|
||||
@ -75,6 +76,7 @@ impl NodeServer {
|
||||
let tls_acceptor = tls_acceptor.clone();
|
||||
let svc = svc.clone();
|
||||
|
||||
state.declare_myself_public().await;
|
||||
let state = state.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut certificates = Vec::new();
|
||||
@ -106,6 +108,8 @@ impl NodeServer {
|
||||
}));
|
||||
let svc = ServiceBuilder::new().layer(extension_layer).service(svc);
|
||||
|
||||
let ip = addr.ip().to_string();
|
||||
state.add_conn(&ip).await;
|
||||
if let Err(e) = http
|
||||
.serve_connection(
|
||||
TokioIo::new(conn),
|
||||
@ -117,6 +121,7 @@ impl NodeServer {
|
||||
{
|
||||
println!("Error serving connection: {}", e);
|
||||
}
|
||||
state.delete_conn(&ip).await;
|
||||
});
|
||||
}
|
||||
}
|
||||
@ -151,18 +156,17 @@ impl Update for NodeServer {
|
||||
let my_ip = self.state.get_my_ip().await;
|
||||
|
||||
let stream = async_stream::stream! {
|
||||
state.declare_myself_public().await;
|
||||
state.add_conn(&remote_ip).await;
|
||||
|
||||
let known_nodes: Vec<NodeUpdate> = state.get_nodes().await.into_iter().map(Into::into).collect();
|
||||
for update in known_nodes {
|
||||
yield Ok(update);
|
||||
}
|
||||
|
||||
let error_status: Status;
|
||||
let error_status: Status; // Gets initialized inside loop
|
||||
let mut timeout = interval(state.get_timeout());
|
||||
loop {
|
||||
tokio::select! {
|
||||
Some(msg) = inbound.next() => {
|
||||
timeout = interval(state.get_timeout());
|
||||
match msg {
|
||||
Ok(update) => {
|
||||
if update.ip == remote_ip {
|
||||
@ -174,7 +178,7 @@ impl Update for NodeServer {
|
||||
}
|
||||
|
||||
if state.process_node_update(update.clone().into()).await {
|
||||
// if process update returns true, the update must be forwarded
|
||||
// If process update returns true, the update must be forwarded
|
||||
if tx.send((remote_ip.clone(), update).into()).is_err() {
|
||||
println!("Tokio broadcast receivers had an issue consuming the channel");
|
||||
};
|
||||
@ -188,18 +192,21 @@ impl Update for NodeServer {
|
||||
}
|
||||
Ok(update) = rx.recv() => {
|
||||
if update.sender_ip != remote_ip {
|
||||
// don't bounce back the update we just received
|
||||
// Don't bounce back the update we just received
|
||||
yield Ok(update.update);
|
||||
}
|
||||
// disconnect client if too many connections are active
|
||||
// TODO: check whether to disconnect the client if too many connections are active
|
||||
if tx.receiver_count() > 9 {
|
||||
error_status = Status::internal("Already have too many clients. Connect to another server.");
|
||||
break;
|
||||
}
|
||||
}
|
||||
_ = timeout.tick() => {
|
||||
error_status = Status::internal(format!("Disconnecting after {}s timeout", state.get_timeout().as_secs()));
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
state.delete_conn(&remote_ip).await;
|
||||
yield Err(error_status);
|
||||
};
|
||||
|
||||
|
@ -51,7 +51,7 @@ pub async fn heartbeat(
|
||||
loop {
|
||||
interval.tick().await;
|
||||
println!("Heartbeat...");
|
||||
state.remove_inactive_nodes(HEARTBEAT_INTERVAL * 3).await;
|
||||
state.remove_inactive_nodes().await;
|
||||
let connected_ips = state.get_connected_ips().await;
|
||||
println!("Connected nodes ({}): {:?}", connected_ips.len(), connected_ips);
|
||||
let _ = tx.send((state.get_my_ip().await, state.get_my_info().await).into());
|
||||
@ -127,7 +127,7 @@ async fn main() {
|
||||
let my_ip = resolve_my_ipv4().await; // Guaranteed to be correct IPv4
|
||||
println!("Starting on IP {}", my_ip);
|
||||
|
||||
let state = Arc::new(State::new(my_ip.clone()));
|
||||
let state = Arc::new(State::new(my_ip.clone(), HEARTBEAT_INTERVAL * 3));
|
||||
let sol_client = Arc::new(init_token(state.clone(), ra_cfg.clone()).await);
|
||||
|
||||
let mut tasks = JoinSet::new();
|
||||
|
Loading…
Reference in New Issue
Block a user