switch to tokyo rwlock
This commit is contained in:
		
							parent
							
								
									e5e4109007
								
							
						
					
					
						commit
						eceafd9de8
					
				
							
								
								
									
										151
									
								
								src/datastore.rs
									
									
									
									
									
								
							
							
								
								
								
								
								
									
									
								
							
						
						
									
										151
									
								
								src/datastore.rs
									
									
									
									
									
								
							| @ -2,8 +2,8 @@ use crate::persistence::{SealError, SealedFile}; | ||||
| use serde::{Deserialize, Serialize}; | ||||
| use serde_with::{serde_as, TimestampSeconds}; | ||||
| use std::collections::{HashMap, HashSet}; | ||||
| use std::sync::RwLock; | ||||
| use std::time::{Duration, SystemTime, UNIX_EPOCH}; | ||||
| use tokio::sync::RwLock; | ||||
| 
 | ||||
| type IP = String; | ||||
| const LOCAL_INFO_FILE: &str = "/host/main/node_info"; | ||||
| @ -104,153 +104,143 @@ impl State { | ||||
|         Self { my_ip, nodes: RwLock::new(nodes), conns: RwLock::new(HashSet::new()) } | ||||
|     } | ||||
| 
 | ||||
|     pub fn add_conn(&self, ip: &str) { | ||||
|         self.increase_mratls_conns(); | ||||
|         self.add_mratls_conn(ip); | ||||
|     pub async fn add_conn(&self, ip: &str) { | ||||
|         self.increase_mratls_conns().await; | ||||
|         self.add_mratls_conn(ip).await; | ||||
|     } | ||||
| 
 | ||||
|     pub fn delete_conn(&self, ip: &str) { | ||||
|         self.decrease_mratls_conns(); | ||||
|         self.delete_mratls_conn(ip); | ||||
|     pub async fn delete_conn(&self, ip: &str) { | ||||
|         self.decrease_mratls_conns().await; | ||||
|         self.delete_mratls_conn(ip).await; | ||||
|     } | ||||
| 
 | ||||
|     fn add_mratls_conn(&self, ip: &str) { | ||||
|         if let Ok(mut conns) = self.conns.write() { | ||||
|     async fn add_mratls_conn(&self, ip: &str) { | ||||
|         let mut conns = self.conns.write().await; | ||||
|         conns.insert(ip.to_string()); | ||||
|     } | ||||
|     } | ||||
| 
 | ||||
|     fn delete_mratls_conn(&self, ip: &str) { | ||||
|         if let Ok(mut conns) = self.conns.write() { | ||||
|     async fn delete_mratls_conn(&self, ip: &str) { | ||||
|         let mut conns = self.conns.write().await; | ||||
|         conns.remove(ip); | ||||
|     } | ||||
| 
 | ||||
|     pub async fn increase_mint_requests(&self) { | ||||
|         let mut nodes = self.nodes.write().await; | ||||
|         if let Some(my_info) = nodes.get(&self.my_ip) { | ||||
|             let mut updated_info = my_info.clone(); | ||||
|             updated_info.mint_requests += 1; | ||||
|             let _ = nodes.insert(self.my_ip.clone(), updated_info); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     pub fn increase_mint_requests(&self) { | ||||
|         if let Ok(mut nodes) = self.nodes.write() { | ||||
|     pub async fn increase_mints(&self) { | ||||
|         let mut nodes = self.nodes.write().await; | ||||
|         if let Some(my_info) = nodes.get_mut(&self.my_ip) { | ||||
|                 *my_info = NodeInfo { mint_requests: my_info.mint_requests + 1, ..my_info.clone() }; | ||||
|             } | ||||
|             let mut updated_info = my_info.clone(); | ||||
|             updated_info.mints += 1; | ||||
|             let _ = nodes.insert(self.my_ip.clone(), updated_info); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     pub fn increase_mints(&self) { | ||||
|         if let Ok(mut nodes) = self.nodes.write() { | ||||
|     pub async fn increase_mratls_conns(&self) { | ||||
|         let mut nodes = self.nodes.write().await; | ||||
|         if let Some(my_info) = nodes.get_mut(&self.my_ip) { | ||||
|                 *my_info = NodeInfo { mints: my_info.mints + 1, ..my_info.clone() }; | ||||
|             } | ||||
|             let mut updated_info = my_info.clone(); | ||||
|             updated_info.mratls_conns += 1; | ||||
|             let _ = nodes.insert(self.my_ip.clone(), updated_info); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     pub fn increase_mratls_conns(&self) { | ||||
|         if let Ok(mut nodes) = self.nodes.write() { | ||||
|             if let Some(my_info) = nodes.get_mut(&self.my_ip) { | ||||
|                 *my_info = NodeInfo { mratls_conns: my_info.mratls_conns + 1, ..my_info.clone() }; | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     pub fn decrease_mratls_conns(&self) { | ||||
|         if let Ok(mut nodes) = self.nodes.write() { | ||||
|     pub async fn decrease_mratls_conns(&self) { | ||||
|         let mut nodes = self.nodes.write().await; | ||||
|         if let Some(my_info) = nodes.get_mut(&self.my_ip) { | ||||
|             if my_info.mratls_conns > 0 { | ||||
|                     *my_info = | ||||
|                         NodeInfo { mratls_conns: my_info.mratls_conns - 1, ..my_info.clone() }; | ||||
|                 } | ||||
|                 let mut updated_info = my_info.clone(); | ||||
|                 updated_info.mratls_conns -= 1; | ||||
|                 let _ = nodes.insert(self.my_ip.clone(), updated_info); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     pub fn increase_disk_attacks(&self) { | ||||
|         if let Ok(mut nodes) = self.nodes.write() { | ||||
|     pub async fn increase_disk_attacks(&self) { | ||||
|         let mut nodes = self.nodes.write().await; | ||||
|         if let Some(my_info) = nodes.get_mut(&self.my_ip) { | ||||
|                 *my_info = NodeInfo { disk_attacks: my_info.disk_attacks + 1, ..my_info.clone() }; | ||||
|             } | ||||
|             let mut updated_info = my_info.clone(); | ||||
|             updated_info.disk_attacks += 1; | ||||
|             let _ = nodes.insert(self.my_ip.clone(), updated_info); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     pub fn increase_net_attacks(&self) { | ||||
|         if let Ok(mut nodes) = self.nodes.write() { | ||||
|     pub async fn increase_net_attacks(&self) { | ||||
|         let mut nodes = self.nodes.write().await; | ||||
|         if let Some(my_info) = nodes.get_mut(&self.my_ip) { | ||||
|                 *my_info = NodeInfo { net_attacks: my_info.net_attacks + 1, ..my_info.clone() }; | ||||
|             } | ||||
|             let mut updated_info = my_info.clone(); | ||||
|             updated_info.net_attacks += 1; | ||||
|             let _ = nodes.insert(self.my_ip.clone(), updated_info); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     pub fn declare_myself_public(&self) { | ||||
|         if let Ok(mut nodes) = self.nodes.write() { | ||||
|     pub async fn declare_myself_public(&self) { | ||||
|         let mut nodes = self.nodes.write().await; | ||||
|         if let Some(my_info) = nodes.get_mut(&self.my_ip) { | ||||
|                 *my_info = NodeInfo { public: true, ..my_info.clone() }; | ||||
|             } | ||||
|             let mut updated_info = my_info.clone(); | ||||
|             updated_info.public = true; | ||||
|             let _ = nodes.insert(self.my_ip.clone(), updated_info); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     pub fn get_nodes(&self) -> Vec<(String, NodeInfo)> { | ||||
|         if let Ok(nodes) = self.nodes.read() { | ||||
|             return nodes.iter().map(|(k, v)| (k.clone(), v.clone())).collect(); | ||||
|         } | ||||
|         Vec::new() | ||||
|     pub async fn get_nodes(&self) -> Vec<(String, NodeInfo)> { | ||||
|         let nodes = self.nodes.read().await; | ||||
|         nodes.iter().map(|(k, v)| (k.clone(), v.clone())).collect() | ||||
|     } | ||||
| 
 | ||||
|     pub fn get_my_ip(&self) -> String { | ||||
|     pub async fn get_my_ip(&self) -> String { | ||||
|         self.my_ip.clone() | ||||
|     } | ||||
| 
 | ||||
|     pub fn get_my_info(&self) -> NodeInfo { | ||||
|         if let Ok(nodes) = self.nodes.read() { | ||||
|             if let Some(found_info) = nodes.get(&self.my_ip) { | ||||
|                 return found_info.clone(); | ||||
|             } | ||||
|         } | ||||
|         NodeInfo::new_empty() | ||||
|     pub async fn get_my_info(&self) -> NodeInfo { | ||||
|         let nodes = self.nodes.read().await; | ||||
|         nodes.get(&self.my_ip).cloned().unwrap_or(NodeInfo::new_empty()) | ||||
|     } | ||||
| 
 | ||||
|     pub fn get_connected_ips(&self) -> Vec<String> { | ||||
|         if let Ok(conns) = self.conns.read() { | ||||
|             return conns.iter().map(|n| n.clone()).collect(); | ||||
|         } | ||||
|         Vec::new() | ||||
|     pub async fn get_connected_ips(&self) -> Vec<String> { | ||||
|         let conns = self.conns.read().await; | ||||
|         conns.iter().map(|n| n.clone()).collect() | ||||
|     } | ||||
| 
 | ||||
|     // returns a random node that does not have an active connection
 | ||||
|     pub fn get_random_disconnected_ip(&self) -> Option<String> { | ||||
|     pub async fn get_random_disconnected_ip(&self) -> Option<String> { | ||||
|         use rand::{rngs::OsRng, RngCore}; | ||||
| 
 | ||||
|         let conn_ips = self.get_connected_ips(); | ||||
|         if let Ok(nodes) = self.nodes.read() { | ||||
|         let conn_ips = self.get_connected_ips().await; | ||||
|         let nodes = self.nodes.read().await; | ||||
|         let skip = OsRng.next_u64().try_into().unwrap_or(0) % nodes.len(); | ||||
|             return nodes | ||||
|         nodes | ||||
|             .keys() | ||||
|             .map(|ip| ip.to_string()) | ||||
|             .filter(|ip| ip != &self.my_ip && !conn_ips.contains(ip)) | ||||
|             .cycle() | ||||
|                 .nth(skip); | ||||
|         } | ||||
|         None | ||||
|             .nth(skip) | ||||
|     } | ||||
| 
 | ||||
|     /// This returns true if the update should be further forwarded
 | ||||
|     /// For example, we never forward our own updates that came back
 | ||||
|     pub fn process_node_update(&self, (node_ip, node_info): (String, NodeInfo)) -> bool { | ||||
|     pub async fn process_node_update(&self, (node_ip, node_info): (String, NodeInfo)) -> bool { | ||||
|         let is_update_mine = node_ip.eq(&self.my_ip); | ||||
|         let mut is_update_new = false; | ||||
|         if let Ok(mut nodes) = self.nodes.write() { | ||||
|             is_update_new = nodes | ||||
|         let mut nodes = self.nodes.write().await; | ||||
|         let is_update_new = nodes | ||||
|             .get(&node_ip) | ||||
|             .map(|curr_info| node_info.is_newer_than(&curr_info)) | ||||
|             .unwrap_or(true); | ||||
|             if is_update_new { | ||||
|                 let _ = nodes.insert(node_ip.clone(), node_info.clone()); | ||||
|             } | ||||
|         } | ||||
|         if is_update_new { | ||||
|             println!("Inserting: {}, {}", node_ip, node_info.to_json()); | ||||
|             let _ = nodes.insert(node_ip, node_info); | ||||
|         } | ||||
|         is_update_new && !is_update_mine | ||||
|     } | ||||
| 
 | ||||
|     pub fn remove_inactive_nodes(&self, max_age: u64) { | ||||
|         if let Ok(mut nodes) = self.nodes.write() { | ||||
|     pub async fn remove_inactive_nodes(&self, max_age: u64) { | ||||
|         let mut nodes = self.nodes.write().await; | ||||
|         // TODO: Check if it is possible to corrupt SGX system time
 | ||||
|         let now = SystemTime::now(); | ||||
|         nodes.retain(|_, n| { | ||||
| @ -258,5 +248,4 @@ impl State { | ||||
|             age <= max_age | ||||
|         }); | ||||
|     } | ||||
|     } | ||||
| } | ||||
|  | ||||
| @ -14,7 +14,7 @@ pub async fn grpc_new_conn( | ||||
|     ra_cfg: RaTlsConfig, | ||||
|     tx: Sender<InternalNodeUpdate>, | ||||
| ) { | ||||
|     if Ipv4Addr::from_str(&node_ip).is_err() || node_ip == state.get_my_ip() { | ||||
|     if Ipv4Addr::from_str(&node_ip).is_err() || node_ip == state.get_my_ip().await { | ||||
|         println!("IPv4 address is invalid: {node_ip}"); | ||||
|         return; | ||||
|     } | ||||
| @ -47,11 +47,11 @@ impl ConnManager { | ||||
| 
 | ||||
|     async fn connect_to(&self, node_ip: String) { | ||||
|         let state = self.state.clone(); | ||||
|         state.add_conn(&node_ip); | ||||
|         state.add_conn(&node_ip).await; | ||||
|         if let Err(e) = self.connect_to_int(node_ip.clone()).await { | ||||
|             println!("Client connection for {node_ip} failed: {e:?}"); | ||||
|         } | ||||
|         state.delete_conn(&node_ip); | ||||
|         state.delete_conn(&node_ip).await; | ||||
|     } | ||||
| 
 | ||||
|     async fn connect_to_int(&self, remote_ip: String) -> Result<(), Box<dyn std::error::Error>> { | ||||
| @ -104,20 +104,21 @@ impl ConnManager { | ||||
|                 } | ||||
|             }); | ||||
| 
 | ||||
|         let response = client.get_updates(rx_stream).await.map_err(|e| { | ||||
|         let response = client.get_updates(rx_stream).await; | ||||
|         if let Err(e) = response { | ||||
|             println!("Error connecting to {remote_ip}: {e}"); | ||||
|             if e.to_string().contains("QuoteVerifyError") { | ||||
|                 self.state.increase_net_attacks(); | ||||
|                 self.state.increase_net_attacks().await; | ||||
|             } | ||||
|             e | ||||
|         })?; | ||||
|         let mut resp_stream = response.into_inner(); | ||||
|             return Err(e.into()); | ||||
|         } | ||||
|         let mut resp_stream = response.unwrap().into_inner(); | ||||
| 
 | ||||
|         // Immediately send our info as a network update
 | ||||
|         let _ = self.tx.send((self.state.get_my_ip(), self.state.get_my_info()).into()); | ||||
|         let _ = self.tx.send((self.state.get_my_ip().await, self.state.get_my_info().await).into()); | ||||
|         while let Some(update) = resp_stream.message().await? { | ||||
|             // Update the entire network in case the information is new
 | ||||
|             if self.state.process_node_update(update.clone().into()) { | ||||
|             if self.state.process_node_update(update.clone().into()).await { | ||||
|                 // If process update returns true, the update must be forwarded
 | ||||
|                 if self.tx.send((remote_ip.clone(), update).into()).is_err() { | ||||
|                     println!("Tokio broadcast receivers had an issue consuming the channel"); | ||||
| @ -170,12 +171,13 @@ async fn query_keys( | ||||
|     let uri = Uri::from_static("https://example.com"); | ||||
|     let mut client = UpdateClient::with_origin(client, uri); | ||||
| 
 | ||||
|     let response = client.get_keys(tonic::Request::new(Empty {})).await.map_err(|e| { | ||||
|     let response = client.get_keys(tonic::Request::new(Empty {})).await; | ||||
|     if let Err(e) = response { | ||||
|         println!("Error getting keys from {node_ip}: {e}"); | ||||
|         if e.to_string().contains("QuoteVerifyError") { | ||||
|             state.increase_net_attacks(); | ||||
|             state.increase_net_attacks().await; | ||||
|         } | ||||
|         e | ||||
|     })?; | ||||
|     Ok(response.into_inner()) | ||||
|         return Err(e.into()); | ||||
|     } | ||||
|     Ok(response.unwrap().into_inner()) | ||||
| } | ||||
|  | ||||
| @ -92,7 +92,7 @@ impl NodeServer { | ||||
|                 let conn = if let Err(e) = conn { | ||||
|                     println!("Error accepting TLS connection: {e}"); | ||||
|                     if e.to_string().contains("HandshakeFailure") { | ||||
|                         state.increase_net_attacks(); | ||||
|                         state.increase_net_attacks().await; | ||||
|                     } | ||||
|                     return; | ||||
|                 } else { | ||||
| @ -131,12 +131,12 @@ struct ConnInfo { | ||||
| 
 | ||||
| #[tonic::async_trait] | ||||
| impl Update for NodeServer { | ||||
|     type GetUpdatesStream = Pin<Box<dyn Stream<Item = Result<NodeUpdate, Status>> + Send>>; | ||||
| 
 | ||||
|     async fn get_keys(&self, _request: Request<Empty>) -> Result<Response<Keys>, Status> { | ||||
|         Ok(Response::new(self.keys.clone())) | ||||
|     } | ||||
| 
 | ||||
|     type GetUpdatesStream = Pin<Box<dyn Stream<Item = Result<NodeUpdate, Status>> + Send>>; | ||||
| 
 | ||||
|     async fn get_updates( | ||||
|         &self, | ||||
|         req: Request<Streaming<NodeUpdate>>, | ||||
| @ -148,13 +148,13 @@ impl Update for NodeServer { | ||||
|         let mut rx = self.tx.subscribe(); | ||||
|         let mut inbound = req.into_inner(); | ||||
|         let state = self.state.clone(); | ||||
|         let my_ip = self.state.get_my_ip(); | ||||
|         let my_ip = self.state.get_my_ip().await; | ||||
| 
 | ||||
|         let stream = async_stream::stream! { | ||||
|             state.declare_myself_public(); | ||||
|             state.add_conn(&remote_ip); | ||||
|             state.declare_myself_public().await; | ||||
|             state.add_conn(&remote_ip).await; | ||||
| 
 | ||||
|             let known_nodes: Vec<NodeUpdate> = state.get_nodes().into_iter().map(Into::into).collect(); | ||||
|             let known_nodes: Vec<NodeUpdate> = state.get_nodes().await.into_iter().map(Into::into).collect(); | ||||
|             for update in known_nodes { | ||||
|                 yield Ok(update); | ||||
|             } | ||||
| @ -173,7 +173,7 @@ impl Update for NodeServer { | ||||
|                                     println!("Node {remote_ip} is forwarding the update of {}", update.ip); | ||||
|                                 } | ||||
| 
 | ||||
|                                 if state.process_node_update(update.clone().into()) { | ||||
|                                 if state.process_node_update(update.clone().into()).await { | ||||
|                                     // if process update returns true, the update must be forwarded
 | ||||
|                                     if tx.send((remote_ip.clone(), update).into()).is_err() { | ||||
|                                         println!("Tokio broadcast receivers had an issue consuming the channel"); | ||||
| @ -199,7 +199,7 @@ impl Update for NodeServer { | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|             state.delete_conn(&remote_ip); | ||||
|             state.delete_conn(&remote_ip).await; | ||||
|             yield Err(error_status); | ||||
|         }; | ||||
| 
 | ||||
|  | ||||
| @ -47,9 +47,10 @@ impl From<(String, datastore::NodeInfo)> for NodesResp { | ||||
| } | ||||
| 
 | ||||
| #[get("/nodes")] | ||||
| async fn get_nodes(ds: web::Data<Arc<State>>) -> HttpResponse { | ||||
| async fn get_nodes(state: web::Data<Arc<State>>) -> HttpResponse { | ||||
|     let nodes = state.get_nodes().await; | ||||
|     HttpResponse::Ok() | ||||
|         .json(ds.get_nodes().into_iter().map(Into::<NodesResp>::into).collect::<Vec<NodesResp>>()) | ||||
|         .json(nodes.into_iter().map(Into::<NodesResp>::into).collect::<Vec<NodesResp>>()) | ||||
| } | ||||
| 
 | ||||
| #[derive(Deserialize)] | ||||
| @ -64,13 +65,13 @@ async fn mint( | ||||
|     req: web::Json<MintReq>, | ||||
| ) -> impl Responder { | ||||
|     let recipient = req.into_inner().wallet; | ||||
|     state.increase_mint_requests(); | ||||
|     state.increase_mint_requests().await; | ||||
|     let result = | ||||
|         web::block(move || sol_client.mint(&recipient).map_err(|e| e.to_string())).await.unwrap(); // TODO: check if this can get a BlockingError
 | ||||
| 
 | ||||
|     match result { | ||||
|         Ok(s) => { | ||||
|             state.increase_mints(); | ||||
|             state.increase_mints().await; | ||||
|             HttpResponse::Ok().body(format!(r#"{{" signature": "{s} "}}"#)) | ||||
|         } | ||||
|         Err(e) => HttpResponse::InternalServerError().body(format!(r#"{{ "error": "{e}" }}"#)), | ||||
| @ -78,9 +79,9 @@ async fn mint( | ||||
| } | ||||
| 
 | ||||
| #[get("/metrics")] | ||||
| async fn metrics(ds: web::Data<Arc<State>>) -> HttpResponse { | ||||
| async fn metrics(state: web::Data<Arc<State>>) -> HttpResponse { | ||||
|     let mut metrics = String::new(); | ||||
|     for (ip, node) in ds.get_nodes() { | ||||
|     for (ip, node) in state.get_nodes().await { | ||||
|         metrics.push_str(node.to_metrics(&ip).as_str()); | ||||
|     } | ||||
|     HttpResponse::Ok().content_type("text/plain; version=0.0.4; charset=utf-8").body(metrics) | ||||
|  | ||||
							
								
								
									
										10
									
								
								src/main.rs
									
									
									
									
									
								
							
							
								
								
								
								
								
									
									
								
							
						
						
									
										10
									
								
								src/main.rs
									
									
									
									
									
								
							| @ -51,12 +51,12 @@ pub async fn heartbeat( | ||||
|     loop { | ||||
|         interval.tick().await; | ||||
|         println!("Heartbeat..."); | ||||
|         state.remove_inactive_nodes(HEARTBEAT_INTERVAL * 3); | ||||
|         let connected_ips = state.get_connected_ips(); | ||||
|         state.remove_inactive_nodes(HEARTBEAT_INTERVAL * 3).await; | ||||
|         let connected_ips = state.get_connected_ips().await; | ||||
|         println!("Connected nodes ({}): {:?}", connected_ips.len(), connected_ips); | ||||
|         let _ = tx.send((state.get_my_ip(), state.get_my_info()).into()); | ||||
|         let _ = tx.send((state.get_my_ip().await, state.get_my_info().await).into()); | ||||
|         if connected_ips.len() < NUM_CONNECTIONS { | ||||
|             if let Some(node_ip) = state.get_random_disconnected_ip() { | ||||
|             if let Some(node_ip) = state.get_random_disconnected_ip().await { | ||||
|                 println!("Dialing random node {}", node_ip); | ||||
|                 tasks.spawn(grpc_new_conn(node_ip, state.clone(), ra_cfg.clone(), tx.clone())); | ||||
|             } | ||||
| @ -77,7 +77,7 @@ async fn init_token(state: Arc<State>, ra_cfg: RaTlsConfig) -> SolClient { | ||||
|         } | ||||
|         Err(SealError::Attack(e)) => { | ||||
|             println!("The sealed keys are corrupted: {}", e); | ||||
|             state.increase_disk_attacks(); | ||||
|             state.increase_disk_attacks().await; | ||||
|         } | ||||
|         Err(e) => { | ||||
|             println!("Could not read sealed keys: {e}"); | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user