switch to tokyo rwlock

This commit is contained in:
Valentyn Faychuk 2024-12-24 03:01:01 +02:00
parent e5e4109007
commit eceafd9de8
Signed by: valy
GPG Key ID: F1AB995E20FEADC5
5 changed files with 136 additions and 144 deletions

@ -2,8 +2,8 @@ use crate::persistence::{SealError, SealedFile};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_with::{serde_as, TimestampSeconds}; use serde_with::{serde_as, TimestampSeconds};
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::sync::RwLock;
use std::time::{Duration, SystemTime, UNIX_EPOCH}; use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tokio::sync::RwLock;
type IP = String; type IP = String;
const LOCAL_INFO_FILE: &str = "/host/main/node_info"; const LOCAL_INFO_FILE: &str = "/host/main/node_info";
@ -104,153 +104,143 @@ impl State {
Self { my_ip, nodes: RwLock::new(nodes), conns: RwLock::new(HashSet::new()) } Self { my_ip, nodes: RwLock::new(nodes), conns: RwLock::new(HashSet::new()) }
} }
pub fn add_conn(&self, ip: &str) { pub async fn add_conn(&self, ip: &str) {
self.increase_mratls_conns(); self.increase_mratls_conns().await;
self.add_mratls_conn(ip); self.add_mratls_conn(ip).await;
} }
pub fn delete_conn(&self, ip: &str) { pub async fn delete_conn(&self, ip: &str) {
self.decrease_mratls_conns(); self.decrease_mratls_conns().await;
self.delete_mratls_conn(ip); self.delete_mratls_conn(ip).await;
} }
fn add_mratls_conn(&self, ip: &str) { async fn add_mratls_conn(&self, ip: &str) {
if let Ok(mut conns) = self.conns.write() { let mut conns = self.conns.write().await;
conns.insert(ip.to_string()); conns.insert(ip.to_string());
} }
}
fn delete_mratls_conn(&self, ip: &str) { async fn delete_mratls_conn(&self, ip: &str) {
if let Ok(mut conns) = self.conns.write() { let mut conns = self.conns.write().await;
conns.remove(ip); conns.remove(ip);
} }
pub async fn increase_mint_requests(&self) {
let mut nodes = self.nodes.write().await;
if let Some(my_info) = nodes.get(&self.my_ip) {
let mut updated_info = my_info.clone();
updated_info.mint_requests += 1;
let _ = nodes.insert(self.my_ip.clone(), updated_info);
}
} }
pub fn increase_mint_requests(&self) { pub async fn increase_mints(&self) {
if let Ok(mut nodes) = self.nodes.write() { let mut nodes = self.nodes.write().await;
if let Some(my_info) = nodes.get_mut(&self.my_ip) { if let Some(my_info) = nodes.get_mut(&self.my_ip) {
*my_info = NodeInfo { mint_requests: my_info.mint_requests + 1, ..my_info.clone() }; let mut updated_info = my_info.clone();
} updated_info.mints += 1;
let _ = nodes.insert(self.my_ip.clone(), updated_info);
} }
} }
pub fn increase_mints(&self) { pub async fn increase_mratls_conns(&self) {
if let Ok(mut nodes) = self.nodes.write() { let mut nodes = self.nodes.write().await;
if let Some(my_info) = nodes.get_mut(&self.my_ip) { if let Some(my_info) = nodes.get_mut(&self.my_ip) {
*my_info = NodeInfo { mints: my_info.mints + 1, ..my_info.clone() }; let mut updated_info = my_info.clone();
} updated_info.mratls_conns += 1;
let _ = nodes.insert(self.my_ip.clone(), updated_info);
} }
} }
pub fn increase_mratls_conns(&self) { pub async fn decrease_mratls_conns(&self) {
if let Ok(mut nodes) = self.nodes.write() { let mut nodes = self.nodes.write().await;
if let Some(my_info) = nodes.get_mut(&self.my_ip) {
*my_info = NodeInfo { mratls_conns: my_info.mratls_conns + 1, ..my_info.clone() };
}
}
}
pub fn decrease_mratls_conns(&self) {
if let Ok(mut nodes) = self.nodes.write() {
if let Some(my_info) = nodes.get_mut(&self.my_ip) { if let Some(my_info) = nodes.get_mut(&self.my_ip) {
if my_info.mratls_conns > 0 { if my_info.mratls_conns > 0 {
*my_info = let mut updated_info = my_info.clone();
NodeInfo { mratls_conns: my_info.mratls_conns - 1, ..my_info.clone() }; updated_info.mratls_conns -= 1;
} let _ = nodes.insert(self.my_ip.clone(), updated_info);
} }
} }
} }
pub fn increase_disk_attacks(&self) { pub async fn increase_disk_attacks(&self) {
if let Ok(mut nodes) = self.nodes.write() { let mut nodes = self.nodes.write().await;
if let Some(my_info) = nodes.get_mut(&self.my_ip) { if let Some(my_info) = nodes.get_mut(&self.my_ip) {
*my_info = NodeInfo { disk_attacks: my_info.disk_attacks + 1, ..my_info.clone() }; let mut updated_info = my_info.clone();
} updated_info.disk_attacks += 1;
let _ = nodes.insert(self.my_ip.clone(), updated_info);
} }
} }
pub fn increase_net_attacks(&self) { pub async fn increase_net_attacks(&self) {
if let Ok(mut nodes) = self.nodes.write() { let mut nodes = self.nodes.write().await;
if let Some(my_info) = nodes.get_mut(&self.my_ip) { if let Some(my_info) = nodes.get_mut(&self.my_ip) {
*my_info = NodeInfo { net_attacks: my_info.net_attacks + 1, ..my_info.clone() }; let mut updated_info = my_info.clone();
} updated_info.net_attacks += 1;
let _ = nodes.insert(self.my_ip.clone(), updated_info);
} }
} }
pub fn declare_myself_public(&self) { pub async fn declare_myself_public(&self) {
if let Ok(mut nodes) = self.nodes.write() { let mut nodes = self.nodes.write().await;
if let Some(my_info) = nodes.get_mut(&self.my_ip) { if let Some(my_info) = nodes.get_mut(&self.my_ip) {
*my_info = NodeInfo { public: true, ..my_info.clone() }; let mut updated_info = my_info.clone();
} updated_info.public = true;
let _ = nodes.insert(self.my_ip.clone(), updated_info);
} }
} }
pub fn get_nodes(&self) -> Vec<(String, NodeInfo)> { pub async fn get_nodes(&self) -> Vec<(String, NodeInfo)> {
if let Ok(nodes) = self.nodes.read() { let nodes = self.nodes.read().await;
return nodes.iter().map(|(k, v)| (k.clone(), v.clone())).collect(); nodes.iter().map(|(k, v)| (k.clone(), v.clone())).collect()
}
Vec::new()
} }
pub fn get_my_ip(&self) -> String { pub async fn get_my_ip(&self) -> String {
self.my_ip.clone() self.my_ip.clone()
} }
pub fn get_my_info(&self) -> NodeInfo { pub async fn get_my_info(&self) -> NodeInfo {
if let Ok(nodes) = self.nodes.read() { let nodes = self.nodes.read().await;
if let Some(found_info) = nodes.get(&self.my_ip) { nodes.get(&self.my_ip).cloned().unwrap_or(NodeInfo::new_empty())
return found_info.clone();
}
}
NodeInfo::new_empty()
} }
pub fn get_connected_ips(&self) -> Vec<String> { pub async fn get_connected_ips(&self) -> Vec<String> {
if let Ok(conns) = self.conns.read() { let conns = self.conns.read().await;
return conns.iter().map(|n| n.clone()).collect(); conns.iter().map(|n| n.clone()).collect()
}
Vec::new()
} }
// returns a random node that does not have an active connection // returns a random node that does not have an active connection
pub fn get_random_disconnected_ip(&self) -> Option<String> { pub async fn get_random_disconnected_ip(&self) -> Option<String> {
use rand::{rngs::OsRng, RngCore}; use rand::{rngs::OsRng, RngCore};
let conn_ips = self.get_connected_ips(); let conn_ips = self.get_connected_ips().await;
if let Ok(nodes) = self.nodes.read() { let nodes = self.nodes.read().await;
let skip = OsRng.next_u64().try_into().unwrap_or(0) % nodes.len(); let skip = OsRng.next_u64().try_into().unwrap_or(0) % nodes.len();
return nodes nodes
.keys() .keys()
.map(|ip| ip.to_string()) .map(|ip| ip.to_string())
.filter(|ip| ip != &self.my_ip && !conn_ips.contains(ip)) .filter(|ip| ip != &self.my_ip && !conn_ips.contains(ip))
.cycle() .cycle()
.nth(skip); .nth(skip)
}
None
} }
/// This returns true if the update should be further forwarded /// This returns true if the update should be further forwarded
/// For example, we never forward our own updates that came back /// For example, we never forward our own updates that came back
pub fn process_node_update(&self, (node_ip, node_info): (String, NodeInfo)) -> bool { pub async fn process_node_update(&self, (node_ip, node_info): (String, NodeInfo)) -> bool {
let is_update_mine = node_ip.eq(&self.my_ip); let is_update_mine = node_ip.eq(&self.my_ip);
let mut is_update_new = false; let mut nodes = self.nodes.write().await;
if let Ok(mut nodes) = self.nodes.write() { let is_update_new = nodes
is_update_new = nodes
.get(&node_ip) .get(&node_ip)
.map(|curr_info| node_info.is_newer_than(&curr_info)) .map(|curr_info| node_info.is_newer_than(&curr_info))
.unwrap_or(true); .unwrap_or(true);
if is_update_new {
let _ = nodes.insert(node_ip.clone(), node_info.clone());
}
}
if is_update_new { if is_update_new {
println!("Inserting: {}, {}", node_ip, node_info.to_json()); println!("Inserting: {}, {}", node_ip, node_info.to_json());
let _ = nodes.insert(node_ip, node_info);
} }
is_update_new && !is_update_mine is_update_new && !is_update_mine
} }
pub fn remove_inactive_nodes(&self, max_age: u64) { pub async fn remove_inactive_nodes(&self, max_age: u64) {
if let Ok(mut nodes) = self.nodes.write() { let mut nodes = self.nodes.write().await;
// TODO: Check if it is possible to corrupt SGX system time // TODO: Check if it is possible to corrupt SGX system time
let now = SystemTime::now(); let now = SystemTime::now();
nodes.retain(|_, n| { nodes.retain(|_, n| {
@ -259,4 +249,3 @@ impl State {
}); });
} }
} }
}

@ -14,7 +14,7 @@ pub async fn grpc_new_conn(
ra_cfg: RaTlsConfig, ra_cfg: RaTlsConfig,
tx: Sender<InternalNodeUpdate>, tx: Sender<InternalNodeUpdate>,
) { ) {
if Ipv4Addr::from_str(&node_ip).is_err() || node_ip == state.get_my_ip() { if Ipv4Addr::from_str(&node_ip).is_err() || node_ip == state.get_my_ip().await {
println!("IPv4 address is invalid: {node_ip}"); println!("IPv4 address is invalid: {node_ip}");
return; return;
} }
@ -47,11 +47,11 @@ impl ConnManager {
async fn connect_to(&self, node_ip: String) { async fn connect_to(&self, node_ip: String) {
let state = self.state.clone(); let state = self.state.clone();
state.add_conn(&node_ip); state.add_conn(&node_ip).await;
if let Err(e) = self.connect_to_int(node_ip.clone()).await { if let Err(e) = self.connect_to_int(node_ip.clone()).await {
println!("Client connection for {node_ip} failed: {e:?}"); println!("Client connection for {node_ip} failed: {e:?}");
} }
state.delete_conn(&node_ip); state.delete_conn(&node_ip).await;
} }
async fn connect_to_int(&self, remote_ip: String) -> Result<(), Box<dyn std::error::Error>> { async fn connect_to_int(&self, remote_ip: String) -> Result<(), Box<dyn std::error::Error>> {
@ -104,20 +104,21 @@ impl ConnManager {
} }
}); });
let response = client.get_updates(rx_stream).await.map_err(|e| { let response = client.get_updates(rx_stream).await;
if let Err(e) = response {
println!("Error connecting to {remote_ip}: {e}"); println!("Error connecting to {remote_ip}: {e}");
if e.to_string().contains("QuoteVerifyError") { if e.to_string().contains("QuoteVerifyError") {
self.state.increase_net_attacks(); self.state.increase_net_attacks().await;
} }
e return Err(e.into());
})?; }
let mut resp_stream = response.into_inner(); let mut resp_stream = response.unwrap().into_inner();
// Immediately send our info as a network update // Immediately send our info as a network update
let _ = self.tx.send((self.state.get_my_ip(), self.state.get_my_info()).into()); let _ = self.tx.send((self.state.get_my_ip().await, self.state.get_my_info().await).into());
while let Some(update) = resp_stream.message().await? { while let Some(update) = resp_stream.message().await? {
// Update the entire network in case the information is new // Update the entire network in case the information is new
if self.state.process_node_update(update.clone().into()) { if self.state.process_node_update(update.clone().into()).await {
// If process update returns true, the update must be forwarded // If process update returns true, the update must be forwarded
if self.tx.send((remote_ip.clone(), update).into()).is_err() { if self.tx.send((remote_ip.clone(), update).into()).is_err() {
println!("Tokio broadcast receivers had an issue consuming the channel"); println!("Tokio broadcast receivers had an issue consuming the channel");
@ -170,12 +171,13 @@ async fn query_keys(
let uri = Uri::from_static("https://example.com"); let uri = Uri::from_static("https://example.com");
let mut client = UpdateClient::with_origin(client, uri); let mut client = UpdateClient::with_origin(client, uri);
let response = client.get_keys(tonic::Request::new(Empty {})).await.map_err(|e| { let response = client.get_keys(tonic::Request::new(Empty {})).await;
if let Err(e) = response {
println!("Error getting keys from {node_ip}: {e}"); println!("Error getting keys from {node_ip}: {e}");
if e.to_string().contains("QuoteVerifyError") { if e.to_string().contains("QuoteVerifyError") {
state.increase_net_attacks(); state.increase_net_attacks().await;
} }
e return Err(e.into());
})?; }
Ok(response.into_inner()) Ok(response.unwrap().into_inner())
} }

@ -92,7 +92,7 @@ impl NodeServer {
let conn = if let Err(e) = conn { let conn = if let Err(e) = conn {
println!("Error accepting TLS connection: {e}"); println!("Error accepting TLS connection: {e}");
if e.to_string().contains("HandshakeFailure") { if e.to_string().contains("HandshakeFailure") {
state.increase_net_attacks(); state.increase_net_attacks().await;
} }
return; return;
} else { } else {
@ -131,12 +131,12 @@ struct ConnInfo {
#[tonic::async_trait] #[tonic::async_trait]
impl Update for NodeServer { impl Update for NodeServer {
type GetUpdatesStream = Pin<Box<dyn Stream<Item = Result<NodeUpdate, Status>> + Send>>;
async fn get_keys(&self, _request: Request<Empty>) -> Result<Response<Keys>, Status> { async fn get_keys(&self, _request: Request<Empty>) -> Result<Response<Keys>, Status> {
Ok(Response::new(self.keys.clone())) Ok(Response::new(self.keys.clone()))
} }
type GetUpdatesStream = Pin<Box<dyn Stream<Item = Result<NodeUpdate, Status>> + Send>>;
async fn get_updates( async fn get_updates(
&self, &self,
req: Request<Streaming<NodeUpdate>>, req: Request<Streaming<NodeUpdate>>,
@ -148,13 +148,13 @@ impl Update for NodeServer {
let mut rx = self.tx.subscribe(); let mut rx = self.tx.subscribe();
let mut inbound = req.into_inner(); let mut inbound = req.into_inner();
let state = self.state.clone(); let state = self.state.clone();
let my_ip = self.state.get_my_ip(); let my_ip = self.state.get_my_ip().await;
let stream = async_stream::stream! { let stream = async_stream::stream! {
state.declare_myself_public(); state.declare_myself_public().await;
state.add_conn(&remote_ip); state.add_conn(&remote_ip).await;
let known_nodes: Vec<NodeUpdate> = state.get_nodes().into_iter().map(Into::into).collect(); let known_nodes: Vec<NodeUpdate> = state.get_nodes().await.into_iter().map(Into::into).collect();
for update in known_nodes { for update in known_nodes {
yield Ok(update); yield Ok(update);
} }
@ -173,7 +173,7 @@ impl Update for NodeServer {
println!("Node {remote_ip} is forwarding the update of {}", update.ip); println!("Node {remote_ip} is forwarding the update of {}", update.ip);
} }
if state.process_node_update(update.clone().into()) { if state.process_node_update(update.clone().into()).await {
// if process update returns true, the update must be forwarded // if process update returns true, the update must be forwarded
if tx.send((remote_ip.clone(), update).into()).is_err() { if tx.send((remote_ip.clone(), update).into()).is_err() {
println!("Tokio broadcast receivers had an issue consuming the channel"); println!("Tokio broadcast receivers had an issue consuming the channel");
@ -199,7 +199,7 @@ impl Update for NodeServer {
} }
} }
} }
state.delete_conn(&remote_ip); state.delete_conn(&remote_ip).await;
yield Err(error_status); yield Err(error_status);
}; };

@ -47,9 +47,10 @@ impl From<(String, datastore::NodeInfo)> for NodesResp {
} }
#[get("/nodes")] #[get("/nodes")]
async fn get_nodes(ds: web::Data<Arc<State>>) -> HttpResponse { async fn get_nodes(state: web::Data<Arc<State>>) -> HttpResponse {
let nodes = state.get_nodes().await;
HttpResponse::Ok() HttpResponse::Ok()
.json(ds.get_nodes().into_iter().map(Into::<NodesResp>::into).collect::<Vec<NodesResp>>()) .json(nodes.into_iter().map(Into::<NodesResp>::into).collect::<Vec<NodesResp>>())
} }
#[derive(Deserialize)] #[derive(Deserialize)]
@ -64,13 +65,13 @@ async fn mint(
req: web::Json<MintReq>, req: web::Json<MintReq>,
) -> impl Responder { ) -> impl Responder {
let recipient = req.into_inner().wallet; let recipient = req.into_inner().wallet;
state.increase_mint_requests(); state.increase_mint_requests().await;
let result = let result =
web::block(move || sol_client.mint(&recipient).map_err(|e| e.to_string())).await.unwrap(); // TODO: check if this can get a BlockingError web::block(move || sol_client.mint(&recipient).map_err(|e| e.to_string())).await.unwrap(); // TODO: check if this can get a BlockingError
match result { match result {
Ok(s) => { Ok(s) => {
state.increase_mints(); state.increase_mints().await;
HttpResponse::Ok().body(format!(r#"{{" signature": "{s} "}}"#)) HttpResponse::Ok().body(format!(r#"{{" signature": "{s} "}}"#))
} }
Err(e) => HttpResponse::InternalServerError().body(format!(r#"{{ "error": "{e}" }}"#)), Err(e) => HttpResponse::InternalServerError().body(format!(r#"{{ "error": "{e}" }}"#)),
@ -78,9 +79,9 @@ async fn mint(
} }
#[get("/metrics")] #[get("/metrics")]
async fn metrics(ds: web::Data<Arc<State>>) -> HttpResponse { async fn metrics(state: web::Data<Arc<State>>) -> HttpResponse {
let mut metrics = String::new(); let mut metrics = String::new();
for (ip, node) in ds.get_nodes() { for (ip, node) in state.get_nodes().await {
metrics.push_str(node.to_metrics(&ip).as_str()); metrics.push_str(node.to_metrics(&ip).as_str());
} }
HttpResponse::Ok().content_type("text/plain; version=0.0.4; charset=utf-8").body(metrics) HttpResponse::Ok().content_type("text/plain; version=0.0.4; charset=utf-8").body(metrics)

@ -51,12 +51,12 @@ pub async fn heartbeat(
loop { loop {
interval.tick().await; interval.tick().await;
println!("Heartbeat..."); println!("Heartbeat...");
state.remove_inactive_nodes(HEARTBEAT_INTERVAL * 3); state.remove_inactive_nodes(HEARTBEAT_INTERVAL * 3).await;
let connected_ips = state.get_connected_ips(); let connected_ips = state.get_connected_ips().await;
println!("Connected nodes ({}): {:?}", connected_ips.len(), connected_ips); println!("Connected nodes ({}): {:?}", connected_ips.len(), connected_ips);
let _ = tx.send((state.get_my_ip(), state.get_my_info()).into()); let _ = tx.send((state.get_my_ip().await, state.get_my_info().await).into());
if connected_ips.len() < NUM_CONNECTIONS { if connected_ips.len() < NUM_CONNECTIONS {
if let Some(node_ip) = state.get_random_disconnected_ip() { if let Some(node_ip) = state.get_random_disconnected_ip().await {
println!("Dialing random node {}", node_ip); println!("Dialing random node {}", node_ip);
tasks.spawn(grpc_new_conn(node_ip, state.clone(), ra_cfg.clone(), tx.clone())); tasks.spawn(grpc_new_conn(node_ip, state.clone(), ra_cfg.clone(), tx.clone()));
} }
@ -77,7 +77,7 @@ async fn init_token(state: Arc<State>, ra_cfg: RaTlsConfig) -> SolClient {
} }
Err(SealError::Attack(e)) => { Err(SealError::Attack(e)) => {
println!("The sealed keys are corrupted: {}", e); println!("The sealed keys are corrupted: {}", e);
state.increase_disk_attacks(); state.increase_disk_attacks().await;
} }
Err(e) => { Err(e) => {
println!("Could not read sealed keys: {e}"); println!("Could not read sealed keys: {e}");