From eceafd9de8fe89fbdd49bcfed0553c22a4e5869d Mon Sep 17 00:00:00 2001 From: Valentyn Faychuk Date: Tue, 24 Dec 2024 03:01:01 +0200 Subject: [PATCH] switch to tokyo rwlock --- src/datastore.rs | 207 +++++++++++++++++++++------------------------ src/grpc/client.rs | 32 +++---- src/grpc/server.rs | 18 ++-- src/http_server.rs | 13 +-- src/main.rs | 10 +-- 5 files changed, 136 insertions(+), 144 deletions(-) diff --git a/src/datastore.rs b/src/datastore.rs index 178e633..9626b19 100644 --- a/src/datastore.rs +++ b/src/datastore.rs @@ -2,8 +2,8 @@ use crate::persistence::{SealError, SealedFile}; use serde::{Deserialize, Serialize}; use serde_with::{serde_as, TimestampSeconds}; use std::collections::{HashMap, HashSet}; -use std::sync::RwLock; use std::time::{Duration, SystemTime, UNIX_EPOCH}; +use tokio::sync::RwLock; type IP = String; const LOCAL_INFO_FILE: &str = "/host/main/node_info"; @@ -104,159 +104,148 @@ impl State { Self { my_ip, nodes: RwLock::new(nodes), conns: RwLock::new(HashSet::new()) } } - pub fn add_conn(&self, ip: &str) { - self.increase_mratls_conns(); - self.add_mratls_conn(ip); + pub async fn add_conn(&self, ip: &str) { + self.increase_mratls_conns().await; + self.add_mratls_conn(ip).await; } - pub fn delete_conn(&self, ip: &str) { - self.decrease_mratls_conns(); - self.delete_mratls_conn(ip); + pub async fn delete_conn(&self, ip: &str) { + self.decrease_mratls_conns().await; + self.delete_mratls_conn(ip).await; } - fn add_mratls_conn(&self, ip: &str) { - if let Ok(mut conns) = self.conns.write() { - conns.insert(ip.to_string()); + async fn add_mratls_conn(&self, ip: &str) { + let mut conns = self.conns.write().await; + conns.insert(ip.to_string()); + } + + async fn delete_mratls_conn(&self, ip: &str) { + let mut conns = self.conns.write().await; + conns.remove(ip); + } + + pub async fn increase_mint_requests(&self) { + let mut nodes = self.nodes.write().await; + if let Some(my_info) = nodes.get(&self.my_ip) { + let mut updated_info = my_info.clone(); + updated_info.mint_requests += 1; + let _ = nodes.insert(self.my_ip.clone(), updated_info); } } - fn delete_mratls_conn(&self, ip: &str) { - if let Ok(mut conns) = self.conns.write() { - conns.remove(ip); + pub async fn increase_mints(&self) { + let mut nodes = self.nodes.write().await; + if let Some(my_info) = nodes.get_mut(&self.my_ip) { + let mut updated_info = my_info.clone(); + updated_info.mints += 1; + let _ = nodes.insert(self.my_ip.clone(), updated_info); } } - pub fn increase_mint_requests(&self) { - if let Ok(mut nodes) = self.nodes.write() { - if let Some(my_info) = nodes.get_mut(&self.my_ip) { - *my_info = NodeInfo { mint_requests: my_info.mint_requests + 1, ..my_info.clone() }; + pub async fn increase_mratls_conns(&self) { + let mut nodes = self.nodes.write().await; + if let Some(my_info) = nodes.get_mut(&self.my_ip) { + let mut updated_info = my_info.clone(); + updated_info.mratls_conns += 1; + let _ = nodes.insert(self.my_ip.clone(), updated_info); + } + } + + pub async fn decrease_mratls_conns(&self) { + let mut nodes = self.nodes.write().await; + if let Some(my_info) = nodes.get_mut(&self.my_ip) { + if my_info.mratls_conns > 0 { + let mut updated_info = my_info.clone(); + updated_info.mratls_conns -= 1; + let _ = nodes.insert(self.my_ip.clone(), updated_info); } } } - pub fn increase_mints(&self) { - if let Ok(mut nodes) = self.nodes.write() { - if let Some(my_info) = nodes.get_mut(&self.my_ip) { - *my_info = NodeInfo { mints: my_info.mints + 1, ..my_info.clone() }; - } + pub async fn increase_disk_attacks(&self) { + let mut nodes = self.nodes.write().await; + if let Some(my_info) = nodes.get_mut(&self.my_ip) { + let mut updated_info = my_info.clone(); + updated_info.disk_attacks += 1; + let _ = nodes.insert(self.my_ip.clone(), updated_info); } } - pub fn increase_mratls_conns(&self) { - if let Ok(mut nodes) = self.nodes.write() { - if let Some(my_info) = nodes.get_mut(&self.my_ip) { - *my_info = NodeInfo { mratls_conns: my_info.mratls_conns + 1, ..my_info.clone() }; - } + pub async fn increase_net_attacks(&self) { + let mut nodes = self.nodes.write().await; + if let Some(my_info) = nodes.get_mut(&self.my_ip) { + let mut updated_info = my_info.clone(); + updated_info.net_attacks += 1; + let _ = nodes.insert(self.my_ip.clone(), updated_info); } } - pub fn decrease_mratls_conns(&self) { - if let Ok(mut nodes) = self.nodes.write() { - if let Some(my_info) = nodes.get_mut(&self.my_ip) { - if my_info.mratls_conns > 0 { - *my_info = - NodeInfo { mratls_conns: my_info.mratls_conns - 1, ..my_info.clone() }; - } - } + pub async fn declare_myself_public(&self) { + let mut nodes = self.nodes.write().await; + if let Some(my_info) = nodes.get_mut(&self.my_ip) { + let mut updated_info = my_info.clone(); + updated_info.public = true; + let _ = nodes.insert(self.my_ip.clone(), updated_info); } } - pub fn increase_disk_attacks(&self) { - if let Ok(mut nodes) = self.nodes.write() { - if let Some(my_info) = nodes.get_mut(&self.my_ip) { - *my_info = NodeInfo { disk_attacks: my_info.disk_attacks + 1, ..my_info.clone() }; - } - } + pub async fn get_nodes(&self) -> Vec<(String, NodeInfo)> { + let nodes = self.nodes.read().await; + nodes.iter().map(|(k, v)| (k.clone(), v.clone())).collect() } - pub fn increase_net_attacks(&self) { - if let Ok(mut nodes) = self.nodes.write() { - if let Some(my_info) = nodes.get_mut(&self.my_ip) { - *my_info = NodeInfo { net_attacks: my_info.net_attacks + 1, ..my_info.clone() }; - } - } - } - - pub fn declare_myself_public(&self) { - if let Ok(mut nodes) = self.nodes.write() { - if let Some(my_info) = nodes.get_mut(&self.my_ip) { - *my_info = NodeInfo { public: true, ..my_info.clone() }; - } - } - } - - pub fn get_nodes(&self) -> Vec<(String, NodeInfo)> { - if let Ok(nodes) = self.nodes.read() { - return nodes.iter().map(|(k, v)| (k.clone(), v.clone())).collect(); - } - Vec::new() - } - - pub fn get_my_ip(&self) -> String { + pub async fn get_my_ip(&self) -> String { self.my_ip.clone() } - pub fn get_my_info(&self) -> NodeInfo { - if let Ok(nodes) = self.nodes.read() { - if let Some(found_info) = nodes.get(&self.my_ip) { - return found_info.clone(); - } - } - NodeInfo::new_empty() + pub async fn get_my_info(&self) -> NodeInfo { + let nodes = self.nodes.read().await; + nodes.get(&self.my_ip).cloned().unwrap_or(NodeInfo::new_empty()) } - pub fn get_connected_ips(&self) -> Vec { - if let Ok(conns) = self.conns.read() { - return conns.iter().map(|n| n.clone()).collect(); - } - Vec::new() + pub async fn get_connected_ips(&self) -> Vec { + let conns = self.conns.read().await; + conns.iter().map(|n| n.clone()).collect() } // returns a random node that does not have an active connection - pub fn get_random_disconnected_ip(&self) -> Option { + pub async fn get_random_disconnected_ip(&self) -> Option { use rand::{rngs::OsRng, RngCore}; - let conn_ips = self.get_connected_ips(); - if let Ok(nodes) = self.nodes.read() { - let skip = OsRng.next_u64().try_into().unwrap_or(0) % nodes.len(); - return nodes - .keys() - .map(|ip| ip.to_string()) - .filter(|ip| ip != &self.my_ip && !conn_ips.contains(ip)) - .cycle() - .nth(skip); - } - None + let conn_ips = self.get_connected_ips().await; + let nodes = self.nodes.read().await; + let skip = OsRng.next_u64().try_into().unwrap_or(0) % nodes.len(); + nodes + .keys() + .map(|ip| ip.to_string()) + .filter(|ip| ip != &self.my_ip && !conn_ips.contains(ip)) + .cycle() + .nth(skip) } /// This returns true if the update should be further forwarded /// For example, we never forward our own updates that came back - pub fn process_node_update(&self, (node_ip, node_info): (String, NodeInfo)) -> bool { + pub async fn process_node_update(&self, (node_ip, node_info): (String, NodeInfo)) -> bool { let is_update_mine = node_ip.eq(&self.my_ip); - let mut is_update_new = false; - if let Ok(mut nodes) = self.nodes.write() { - is_update_new = nodes - .get(&node_ip) - .map(|curr_info| node_info.is_newer_than(&curr_info)) - .unwrap_or(true); - if is_update_new { - let _ = nodes.insert(node_ip.clone(), node_info.clone()); - } - } + let mut nodes = self.nodes.write().await; + let is_update_new = nodes + .get(&node_ip) + .map(|curr_info| node_info.is_newer_than(&curr_info)) + .unwrap_or(true); if is_update_new { println!("Inserting: {}, {}", node_ip, node_info.to_json()); + let _ = nodes.insert(node_ip, node_info); } is_update_new && !is_update_mine } - pub fn remove_inactive_nodes(&self, max_age: u64) { - if let Ok(mut nodes) = self.nodes.write() { - // TODO: Check if it is possible to corrupt SGX system time - let now = SystemTime::now(); - nodes.retain(|_, n| { - let age = now.duration_since(n.keepalive).unwrap_or(Duration::ZERO).as_secs(); - age <= max_age - }); - } + pub async fn remove_inactive_nodes(&self, max_age: u64) { + let mut nodes = self.nodes.write().await; + // TODO: Check if it is possible to corrupt SGX system time + let now = SystemTime::now(); + nodes.retain(|_, n| { + let age = now.duration_since(n.keepalive).unwrap_or(Duration::ZERO).as_secs(); + age <= max_age + }); } } diff --git a/src/grpc/client.rs b/src/grpc/client.rs index 5f62bbf..43c9ba0 100644 --- a/src/grpc/client.rs +++ b/src/grpc/client.rs @@ -14,7 +14,7 @@ pub async fn grpc_new_conn( ra_cfg: RaTlsConfig, tx: Sender, ) { - if Ipv4Addr::from_str(&node_ip).is_err() || node_ip == state.get_my_ip() { + if Ipv4Addr::from_str(&node_ip).is_err() || node_ip == state.get_my_ip().await { println!("IPv4 address is invalid: {node_ip}"); return; } @@ -47,11 +47,11 @@ impl ConnManager { async fn connect_to(&self, node_ip: String) { let state = self.state.clone(); - state.add_conn(&node_ip); + state.add_conn(&node_ip).await; if let Err(e) = self.connect_to_int(node_ip.clone()).await { println!("Client connection for {node_ip} failed: {e:?}"); } - state.delete_conn(&node_ip); + state.delete_conn(&node_ip).await; } async fn connect_to_int(&self, remote_ip: String) -> Result<(), Box> { @@ -104,20 +104,21 @@ impl ConnManager { } }); - let response = client.get_updates(rx_stream).await.map_err(|e| { + let response = client.get_updates(rx_stream).await; + if let Err(e) = response { println!("Error connecting to {remote_ip}: {e}"); if e.to_string().contains("QuoteVerifyError") { - self.state.increase_net_attacks(); + self.state.increase_net_attacks().await; } - e - })?; - let mut resp_stream = response.into_inner(); + return Err(e.into()); + } + let mut resp_stream = response.unwrap().into_inner(); // Immediately send our info as a network update - let _ = self.tx.send((self.state.get_my_ip(), self.state.get_my_info()).into()); + let _ = self.tx.send((self.state.get_my_ip().await, self.state.get_my_info().await).into()); while let Some(update) = resp_stream.message().await? { // Update the entire network in case the information is new - if self.state.process_node_update(update.clone().into()) { + if self.state.process_node_update(update.clone().into()).await { // If process update returns true, the update must be forwarded if self.tx.send((remote_ip.clone(), update).into()).is_err() { println!("Tokio broadcast receivers had an issue consuming the channel"); @@ -170,12 +171,13 @@ async fn query_keys( let uri = Uri::from_static("https://example.com"); let mut client = UpdateClient::with_origin(client, uri); - let response = client.get_keys(tonic::Request::new(Empty {})).await.map_err(|e| { + let response = client.get_keys(tonic::Request::new(Empty {})).await; + if let Err(e) = response { println!("Error getting keys from {node_ip}: {e}"); if e.to_string().contains("QuoteVerifyError") { - state.increase_net_attacks(); + state.increase_net_attacks().await; } - e - })?; - Ok(response.into_inner()) + return Err(e.into()); + } + Ok(response.unwrap().into_inner()) } diff --git a/src/grpc/server.rs b/src/grpc/server.rs index 42c5a29..6746f0c 100644 --- a/src/grpc/server.rs +++ b/src/grpc/server.rs @@ -92,7 +92,7 @@ impl NodeServer { let conn = if let Err(e) = conn { println!("Error accepting TLS connection: {e}"); if e.to_string().contains("HandshakeFailure") { - state.increase_net_attacks(); + state.increase_net_attacks().await; } return; } else { @@ -131,12 +131,12 @@ struct ConnInfo { #[tonic::async_trait] impl Update for NodeServer { - type GetUpdatesStream = Pin> + Send>>; - async fn get_keys(&self, _request: Request) -> Result, Status> { Ok(Response::new(self.keys.clone())) } + type GetUpdatesStream = Pin> + Send>>; + async fn get_updates( &self, req: Request>, @@ -148,13 +148,13 @@ impl Update for NodeServer { let mut rx = self.tx.subscribe(); let mut inbound = req.into_inner(); let state = self.state.clone(); - let my_ip = self.state.get_my_ip(); + let my_ip = self.state.get_my_ip().await; let stream = async_stream::stream! { - state.declare_myself_public(); - state.add_conn(&remote_ip); + state.declare_myself_public().await; + state.add_conn(&remote_ip).await; - let known_nodes: Vec = state.get_nodes().into_iter().map(Into::into).collect(); + let known_nodes: Vec = state.get_nodes().await.into_iter().map(Into::into).collect(); for update in known_nodes { yield Ok(update); } @@ -173,7 +173,7 @@ impl Update for NodeServer { println!("Node {remote_ip} is forwarding the update of {}", update.ip); } - if state.process_node_update(update.clone().into()) { + if state.process_node_update(update.clone().into()).await { // if process update returns true, the update must be forwarded if tx.send((remote_ip.clone(), update).into()).is_err() { println!("Tokio broadcast receivers had an issue consuming the channel"); @@ -199,7 +199,7 @@ impl Update for NodeServer { } } } - state.delete_conn(&remote_ip); + state.delete_conn(&remote_ip).await; yield Err(error_status); }; diff --git a/src/http_server.rs b/src/http_server.rs index 00742e8..44e4caf 100644 --- a/src/http_server.rs +++ b/src/http_server.rs @@ -47,9 +47,10 @@ impl From<(String, datastore::NodeInfo)> for NodesResp { } #[get("/nodes")] -async fn get_nodes(ds: web::Data>) -> HttpResponse { +async fn get_nodes(state: web::Data>) -> HttpResponse { + let nodes = state.get_nodes().await; HttpResponse::Ok() - .json(ds.get_nodes().into_iter().map(Into::::into).collect::>()) + .json(nodes.into_iter().map(Into::::into).collect::>()) } #[derive(Deserialize)] @@ -64,13 +65,13 @@ async fn mint( req: web::Json, ) -> impl Responder { let recipient = req.into_inner().wallet; - state.increase_mint_requests(); + state.increase_mint_requests().await; let result = web::block(move || sol_client.mint(&recipient).map_err(|e| e.to_string())).await.unwrap(); // TODO: check if this can get a BlockingError match result { Ok(s) => { - state.increase_mints(); + state.increase_mints().await; HttpResponse::Ok().body(format!(r#"{{" signature": "{s} "}}"#)) } Err(e) => HttpResponse::InternalServerError().body(format!(r#"{{ "error": "{e}" }}"#)), @@ -78,9 +79,9 @@ async fn mint( } #[get("/metrics")] -async fn metrics(ds: web::Data>) -> HttpResponse { +async fn metrics(state: web::Data>) -> HttpResponse { let mut metrics = String::new(); - for (ip, node) in ds.get_nodes() { + for (ip, node) in state.get_nodes().await { metrics.push_str(node.to_metrics(&ip).as_str()); } HttpResponse::Ok().content_type("text/plain; version=0.0.4; charset=utf-8").body(metrics) diff --git a/src/main.rs b/src/main.rs index 5f6a27f..82b616c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -51,12 +51,12 @@ pub async fn heartbeat( loop { interval.tick().await; println!("Heartbeat..."); - state.remove_inactive_nodes(HEARTBEAT_INTERVAL * 3); - let connected_ips = state.get_connected_ips(); + state.remove_inactive_nodes(HEARTBEAT_INTERVAL * 3).await; + let connected_ips = state.get_connected_ips().await; println!("Connected nodes ({}): {:?}", connected_ips.len(), connected_ips); - let _ = tx.send((state.get_my_ip(), state.get_my_info()).into()); + let _ = tx.send((state.get_my_ip().await, state.get_my_info().await).into()); if connected_ips.len() < NUM_CONNECTIONS { - if let Some(node_ip) = state.get_random_disconnected_ip() { + if let Some(node_ip) = state.get_random_disconnected_ip().await { println!("Dialing random node {}", node_ip); tasks.spawn(grpc_new_conn(node_ip, state.clone(), ra_cfg.clone(), tx.clone())); } @@ -77,7 +77,7 @@ async fn init_token(state: Arc, ra_cfg: RaTlsConfig) -> SolClient { } Err(SealError::Attack(e)) => { println!("The sealed keys are corrupted: {}", e); - state.increase_disk_attacks(); + state.increase_disk_attacks().await; } Err(e) => { println!("Could not read sealed keys: {e}");