Connection stability fixes
Signed-off-by: Valentyn Faychuk <valy@detee.ltd>
This commit is contained in:
		
							parent
							
								
									eceafd9de8
								
							
						
					
					
						commit
						01e90f874c
					
				| @ -94,14 +94,15 @@ pub struct State { | |||||||
|     my_ip: String, |     my_ip: String, | ||||||
|     nodes: RwLock<HashMap<IP, NodeInfo>>, |     nodes: RwLock<HashMap<IP, NodeInfo>>, | ||||||
|     conns: RwLock<HashSet<IP>>, |     conns: RwLock<HashSet<IP>>, | ||||||
|  |     timeout: u64, | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| impl State { | impl State { | ||||||
|     pub fn new(my_ip: String) -> Self { |     pub fn new(my_ip: String, timeout: u64) -> Self { | ||||||
|         let mut nodes = HashMap::new(); |         let mut nodes = HashMap::new(); | ||||||
|         let my_info = NodeInfo::load(); |         let my_info = NodeInfo::load(); | ||||||
|         nodes.insert(my_ip.clone(), my_info); |         nodes.insert(my_ip.clone(), my_info); | ||||||
|         Self { my_ip, nodes: RwLock::new(nodes), conns: RwLock::new(HashSet::new()) } |         Self { my_ip, timeout, nodes: RwLock::new(nodes), conns: RwLock::new(HashSet::new()) } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     pub async fn add_conn(&self, ip: &str) { |     pub async fn add_conn(&self, ip: &str) { | ||||||
| @ -198,6 +199,10 @@ impl State { | |||||||
|         self.my_ip.clone() |         self.my_ip.clone() | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |     pub fn get_timeout(&self) -> Duration { | ||||||
|  |         Duration::from_secs(self.timeout) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     pub async fn get_my_info(&self) -> NodeInfo { |     pub async fn get_my_info(&self) -> NodeInfo { | ||||||
|         let nodes = self.nodes.read().await; |         let nodes = self.nodes.read().await; | ||||||
|         nodes.get(&self.my_ip).cloned().unwrap_or(NodeInfo::new_empty()) |         nodes.get(&self.my_ip).cloned().unwrap_or(NodeInfo::new_empty()) | ||||||
| @ -205,7 +210,7 @@ impl State { | |||||||
| 
 | 
 | ||||||
|     pub async fn get_connected_ips(&self) -> Vec<String> { |     pub async fn get_connected_ips(&self) -> Vec<String> { | ||||||
|         let conns = self.conns.read().await; |         let conns = self.conns.read().await; | ||||||
|         conns.iter().map(|n| n.clone()).collect() |         conns.iter().cloned().collect() | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     // returns a random node that does not have an active connection
 |     // returns a random node that does not have an active connection
 | ||||||
| @ -239,13 +244,15 @@ impl State { | |||||||
|         is_update_new && !is_update_mine |         is_update_new && !is_update_mine | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     pub async fn remove_inactive_nodes(&self, max_age: u64) { |     pub async fn remove_inactive_nodes(&self) { | ||||||
|         let mut nodes = self.nodes.write().await; |         let mut nodes = self.nodes.write().await; | ||||||
|         // TODO: Check if it is possible to corrupt SGX system time
 |         // TODO: Check if it is possible to corrupt SGX system time
 | ||||||
|  |         // TODO: Double check if we need to clean up nodes
 | ||||||
|         let now = SystemTime::now(); |         let now = SystemTime::now(); | ||||||
|  |         let max_age = self.timeout; | ||||||
|         nodes.retain(|_, n| { |         nodes.retain(|_, n| { | ||||||
|             let age = now.duration_since(n.keepalive).unwrap_or(Duration::ZERO).as_secs(); |             let age = now.duration_since(n.keepalive).unwrap_or(Duration::ZERO).as_secs(); | ||||||
|             age <= max_age |             age < max_age | ||||||
|         }); |         }); | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  | |||||||
| @ -104,19 +104,26 @@ impl ConnManager { | |||||||
|                 } |                 } | ||||||
|             }); |             }); | ||||||
| 
 | 
 | ||||||
|         let response = client.get_updates(rx_stream).await; |         let to_stream = match client.get_updates(rx_stream).await { | ||||||
|         if let Err(e) = response { |             Ok(response) => { | ||||||
|             println!("Error connecting to {remote_ip}: {e}"); |                 let stream = response.into_inner(); | ||||||
|             if e.to_string().contains("QuoteVerifyError") { |                 stream.timeout(self.state.get_timeout()) | ||||||
|                 self.state.increase_net_attacks().await; |  | ||||||
|             } |             } | ||||||
|             return Err(e.into()); |             Err(e) => { | ||||||
|         } |                 println!("Error connecting to {remote_ip}: {e}"); | ||||||
|         let mut resp_stream = response.unwrap().into_inner(); |                 if e.to_string().contains("QuoteVerifyError") { | ||||||
|  |                     self.state.increase_net_attacks().await; | ||||||
|  |                 } | ||||||
|  |                 return Err(e.into()); | ||||||
|  |             } | ||||||
|  |         }; | ||||||
| 
 | 
 | ||||||
|         // Immediately send our info as a network update
 |         // TODO: Check if immediately sending our info as a network update to everybody works
 | ||||||
|         let _ = self.tx.send((self.state.get_my_ip().await, self.state.get_my_info().await).into()); |         //let _ = self.tx.send((self.state.get_my_ip().await, self.state.get_my_info().await).into());
 | ||||||
|         while let Some(update) = resp_stream.message().await? { | 
 | ||||||
|  |         tokio::pin!(to_stream); | ||||||
|  |         let mut updates = to_stream.take_while(Result::is_ok); | ||||||
|  |         while let Some(Ok(Ok(update))) = updates.next().await { | ||||||
|             // Update the entire network in case the information is new
 |             // Update the entire network in case the information is new
 | ||||||
|             if self.state.process_node_update(update.clone().into()).await { |             if self.state.process_node_update(update.clone().into()).await { | ||||||
|                 // If process update returns true, the update must be forwarded
 |                 // If process update returns true, the update must be forwarded
 | ||||||
| @ -126,7 +133,7 @@ impl ConnManager { | |||||||
|             } |             } | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         Ok(()) |         Err(std::io::Error::new(std::io::ErrorKind::Other, "Updates interrupted").into()) | ||||||
|     } |     } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -7,6 +7,7 @@ use detee_sgx::RaTlsConfig; | |||||||
| use rustls::pki_types::CertificateDer; | use rustls::pki_types::CertificateDer; | ||||||
| use std::{pin::Pin, sync::Arc}; | use std::{pin::Pin, sync::Arc}; | ||||||
| use tokio::sync::broadcast::Sender; | use tokio::sync::broadcast::Sender; | ||||||
|  | use tokio::time::interval; | ||||||
| use tokio_stream::{Stream, StreamExt}; | use tokio_stream::{Stream, StreamExt}; | ||||||
| use tonic::{Request, Response, Status, Streaming}; | use tonic::{Request, Response, Status, Streaming}; | ||||||
| 
 | 
 | ||||||
| @ -75,6 +76,7 @@ impl NodeServer { | |||||||
|             let tls_acceptor = tls_acceptor.clone(); |             let tls_acceptor = tls_acceptor.clone(); | ||||||
|             let svc = svc.clone(); |             let svc = svc.clone(); | ||||||
| 
 | 
 | ||||||
|  |             state.declare_myself_public().await; | ||||||
|             let state = state.clone(); |             let state = state.clone(); | ||||||
|             tokio::spawn(async move { |             tokio::spawn(async move { | ||||||
|                 let mut certificates = Vec::new(); |                 let mut certificates = Vec::new(); | ||||||
| @ -106,6 +108,8 @@ impl NodeServer { | |||||||
|                     })); |                     })); | ||||||
|                 let svc = ServiceBuilder::new().layer(extension_layer).service(svc); |                 let svc = ServiceBuilder::new().layer(extension_layer).service(svc); | ||||||
| 
 | 
 | ||||||
|  |                 let ip = addr.ip().to_string(); | ||||||
|  |                 state.add_conn(&ip).await; | ||||||
|                 if let Err(e) = http |                 if let Err(e) = http | ||||||
|                     .serve_connection( |                     .serve_connection( | ||||||
|                         TokioIo::new(conn), |                         TokioIo::new(conn), | ||||||
| @ -117,6 +121,7 @@ impl NodeServer { | |||||||
|                 { |                 { | ||||||
|                     println!("Error serving connection: {}", e); |                     println!("Error serving connection: {}", e); | ||||||
|                 } |                 } | ||||||
|  |                 state.delete_conn(&ip).await; | ||||||
|             }); |             }); | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| @ -151,18 +156,17 @@ impl Update for NodeServer { | |||||||
|         let my_ip = self.state.get_my_ip().await; |         let my_ip = self.state.get_my_ip().await; | ||||||
| 
 | 
 | ||||||
|         let stream = async_stream::stream! { |         let stream = async_stream::stream! { | ||||||
|             state.declare_myself_public().await; |  | ||||||
|             state.add_conn(&remote_ip).await; |  | ||||||
| 
 |  | ||||||
|             let known_nodes: Vec<NodeUpdate> = state.get_nodes().await.into_iter().map(Into::into).collect(); |             let known_nodes: Vec<NodeUpdate> = state.get_nodes().await.into_iter().map(Into::into).collect(); | ||||||
|             for update in known_nodes { |             for update in known_nodes { | ||||||
|                 yield Ok(update); |                 yield Ok(update); | ||||||
|             } |             } | ||||||
| 
 | 
 | ||||||
|             let error_status: Status; |             let error_status: Status; // Gets initialized inside loop
 | ||||||
|  |             let mut timeout = interval(state.get_timeout()); | ||||||
|             loop { |             loop { | ||||||
|                 tokio::select! { |                 tokio::select! { | ||||||
|                     Some(msg) = inbound.next() => { |                     Some(msg) = inbound.next() => { | ||||||
|  |                         timeout = interval(state.get_timeout()); | ||||||
|                         match msg { |                         match msg { | ||||||
|                             Ok(update) => { |                             Ok(update) => { | ||||||
|                                 if update.ip == remote_ip { |                                 if update.ip == remote_ip { | ||||||
| @ -174,7 +178,7 @@ impl Update for NodeServer { | |||||||
|                                 } |                                 } | ||||||
| 
 | 
 | ||||||
|                                 if state.process_node_update(update.clone().into()).await { |                                 if state.process_node_update(update.clone().into()).await { | ||||||
|                                     // if process update returns true, the update must be forwarded
 |                                     // If process update returns true, the update must be forwarded
 | ||||||
|                                     if tx.send((remote_ip.clone(), update).into()).is_err() { |                                     if tx.send((remote_ip.clone(), update).into()).is_err() { | ||||||
|                                         println!("Tokio broadcast receivers had an issue consuming the channel"); |                                         println!("Tokio broadcast receivers had an issue consuming the channel"); | ||||||
|                                     }; |                                     }; | ||||||
| @ -188,18 +192,21 @@ impl Update for NodeServer { | |||||||
|                     } |                     } | ||||||
|                     Ok(update) = rx.recv() => { |                     Ok(update) = rx.recv() => { | ||||||
|                         if update.sender_ip != remote_ip { |                         if update.sender_ip != remote_ip { | ||||||
|                             // don't bounce back the update we just received
 |                             // Don't bounce back the update we just received
 | ||||||
|                             yield Ok(update.update); |                             yield Ok(update.update); | ||||||
|                         } |                         } | ||||||
|                         // disconnect client if too many connections are active
 |                         // TODO: check whether the client should be disconnected if too many connections are active
 | ||||||
|                         if tx.receiver_count() > 9 { |                         if tx.receiver_count() > 9 { | ||||||
|                             error_status = Status::internal("Already have too many clients. Connect to another server."); |                             error_status = Status::internal("Already have too many clients. Connect to another server."); | ||||||
|                             break; |                             break; | ||||||
|                         } |                         } | ||||||
|                     } |                     } | ||||||
|  |                     _ = timeout.tick() => { | ||||||
|  |                         error_status = Status::internal(format!("Disconnecting after {}s timeout", state.get_timeout().as_secs())); | ||||||
|  |                         break; | ||||||
|  |                     } | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|             state.delete_conn(&remote_ip).await; |  | ||||||
|             yield Err(error_status); |             yield Err(error_status); | ||||||
|         }; |         }; | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -51,7 +51,7 @@ pub async fn heartbeat( | |||||||
|     loop { |     loop { | ||||||
|         interval.tick().await; |         interval.tick().await; | ||||||
|         println!("Heartbeat..."); |         println!("Heartbeat..."); | ||||||
|         state.remove_inactive_nodes(HEARTBEAT_INTERVAL * 3).await; |         state.remove_inactive_nodes().await; | ||||||
|         let connected_ips = state.get_connected_ips().await; |         let connected_ips = state.get_connected_ips().await; | ||||||
|         println!("Connected nodes ({}): {:?}", connected_ips.len(), connected_ips); |         println!("Connected nodes ({}): {:?}", connected_ips.len(), connected_ips); | ||||||
|         let _ = tx.send((state.get_my_ip().await, state.get_my_info().await).into()); |         let _ = tx.send((state.get_my_ip().await, state.get_my_info().await).into()); | ||||||
| @ -127,7 +127,7 @@ async fn main() { | |||||||
|     let my_ip = resolve_my_ipv4().await; // Guaranteed to be correct IPv4
 |     let my_ip = resolve_my_ipv4().await; // Guaranteed to be correct IPv4
 | ||||||
|     println!("Starting on IP {}", my_ip); |     println!("Starting on IP {}", my_ip); | ||||||
| 
 | 
 | ||||||
|     let state = Arc::new(State::new(my_ip.clone())); |     let state = Arc::new(State::new(my_ip.clone(), HEARTBEAT_INTERVAL * 3)); | ||||||
|     let sol_client = Arc::new(init_token(state.clone(), ra_cfg.clone()).await); |     let sol_client = Arc::new(init_token(state.clone(), ra_cfg.clone()).await); | ||||||
| 
 | 
 | ||||||
|     let mut tasks = JoinSet::new(); |     let mut tasks = JoinSet::new(); | ||||||
|  | |||||||
		Loading…
	
		Reference in New Issue
	
	Block a user