switch to tokyo rwlock
This commit is contained in:
parent
e5e4109007
commit
eceafd9de8
151
src/datastore.rs
151
src/datastore.rs
@ -2,8 +2,8 @@ use crate::persistence::{SealError, SealedFile};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::{serde_as, TimestampSeconds};
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::sync::RwLock;
|
||||
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
type IP = String;
|
||||
const LOCAL_INFO_FILE: &str = "/host/main/node_info";
|
||||
@ -104,153 +104,143 @@ impl State {
|
||||
Self { my_ip, nodes: RwLock::new(nodes), conns: RwLock::new(HashSet::new()) }
|
||||
}
|
||||
|
||||
pub fn add_conn(&self, ip: &str) {
|
||||
self.increase_mratls_conns();
|
||||
self.add_mratls_conn(ip);
|
||||
pub async fn add_conn(&self, ip: &str) {
|
||||
self.increase_mratls_conns().await;
|
||||
self.add_mratls_conn(ip).await;
|
||||
}
|
||||
|
||||
pub fn delete_conn(&self, ip: &str) {
|
||||
self.decrease_mratls_conns();
|
||||
self.delete_mratls_conn(ip);
|
||||
pub async fn delete_conn(&self, ip: &str) {
|
||||
self.decrease_mratls_conns().await;
|
||||
self.delete_mratls_conn(ip).await;
|
||||
}
|
||||
|
||||
fn add_mratls_conn(&self, ip: &str) {
|
||||
if let Ok(mut conns) = self.conns.write() {
|
||||
async fn add_mratls_conn(&self, ip: &str) {
|
||||
let mut conns = self.conns.write().await;
|
||||
conns.insert(ip.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
fn delete_mratls_conn(&self, ip: &str) {
|
||||
if let Ok(mut conns) = self.conns.write() {
|
||||
async fn delete_mratls_conn(&self, ip: &str) {
|
||||
let mut conns = self.conns.write().await;
|
||||
conns.remove(ip);
|
||||
}
|
||||
|
||||
pub async fn increase_mint_requests(&self) {
|
||||
let mut nodes = self.nodes.write().await;
|
||||
if let Some(my_info) = nodes.get(&self.my_ip) {
|
||||
let mut updated_info = my_info.clone();
|
||||
updated_info.mint_requests += 1;
|
||||
let _ = nodes.insert(self.my_ip.clone(), updated_info);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn increase_mint_requests(&self) {
|
||||
if let Ok(mut nodes) = self.nodes.write() {
|
||||
pub async fn increase_mints(&self) {
|
||||
let mut nodes = self.nodes.write().await;
|
||||
if let Some(my_info) = nodes.get_mut(&self.my_ip) {
|
||||
*my_info = NodeInfo { mint_requests: my_info.mint_requests + 1, ..my_info.clone() };
|
||||
}
|
||||
let mut updated_info = my_info.clone();
|
||||
updated_info.mints += 1;
|
||||
let _ = nodes.insert(self.my_ip.clone(), updated_info);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn increase_mints(&self) {
|
||||
if let Ok(mut nodes) = self.nodes.write() {
|
||||
pub async fn increase_mratls_conns(&self) {
|
||||
let mut nodes = self.nodes.write().await;
|
||||
if let Some(my_info) = nodes.get_mut(&self.my_ip) {
|
||||
*my_info = NodeInfo { mints: my_info.mints + 1, ..my_info.clone() };
|
||||
}
|
||||
let mut updated_info = my_info.clone();
|
||||
updated_info.mratls_conns += 1;
|
||||
let _ = nodes.insert(self.my_ip.clone(), updated_info);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn increase_mratls_conns(&self) {
|
||||
if let Ok(mut nodes) = self.nodes.write() {
|
||||
if let Some(my_info) = nodes.get_mut(&self.my_ip) {
|
||||
*my_info = NodeInfo { mratls_conns: my_info.mratls_conns + 1, ..my_info.clone() };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn decrease_mratls_conns(&self) {
|
||||
if let Ok(mut nodes) = self.nodes.write() {
|
||||
pub async fn decrease_mratls_conns(&self) {
|
||||
let mut nodes = self.nodes.write().await;
|
||||
if let Some(my_info) = nodes.get_mut(&self.my_ip) {
|
||||
if my_info.mratls_conns > 0 {
|
||||
*my_info =
|
||||
NodeInfo { mratls_conns: my_info.mratls_conns - 1, ..my_info.clone() };
|
||||
}
|
||||
let mut updated_info = my_info.clone();
|
||||
updated_info.mratls_conns -= 1;
|
||||
let _ = nodes.insert(self.my_ip.clone(), updated_info);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn increase_disk_attacks(&self) {
|
||||
if let Ok(mut nodes) = self.nodes.write() {
|
||||
pub async fn increase_disk_attacks(&self) {
|
||||
let mut nodes = self.nodes.write().await;
|
||||
if let Some(my_info) = nodes.get_mut(&self.my_ip) {
|
||||
*my_info = NodeInfo { disk_attacks: my_info.disk_attacks + 1, ..my_info.clone() };
|
||||
}
|
||||
let mut updated_info = my_info.clone();
|
||||
updated_info.disk_attacks += 1;
|
||||
let _ = nodes.insert(self.my_ip.clone(), updated_info);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn increase_net_attacks(&self) {
|
||||
if let Ok(mut nodes) = self.nodes.write() {
|
||||
pub async fn increase_net_attacks(&self) {
|
||||
let mut nodes = self.nodes.write().await;
|
||||
if let Some(my_info) = nodes.get_mut(&self.my_ip) {
|
||||
*my_info = NodeInfo { net_attacks: my_info.net_attacks + 1, ..my_info.clone() };
|
||||
}
|
||||
let mut updated_info = my_info.clone();
|
||||
updated_info.net_attacks += 1;
|
||||
let _ = nodes.insert(self.my_ip.clone(), updated_info);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn declare_myself_public(&self) {
|
||||
if let Ok(mut nodes) = self.nodes.write() {
|
||||
pub async fn declare_myself_public(&self) {
|
||||
let mut nodes = self.nodes.write().await;
|
||||
if let Some(my_info) = nodes.get_mut(&self.my_ip) {
|
||||
*my_info = NodeInfo { public: true, ..my_info.clone() };
|
||||
}
|
||||
let mut updated_info = my_info.clone();
|
||||
updated_info.public = true;
|
||||
let _ = nodes.insert(self.my_ip.clone(), updated_info);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_nodes(&self) -> Vec<(String, NodeInfo)> {
|
||||
if let Ok(nodes) = self.nodes.read() {
|
||||
return nodes.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
|
||||
}
|
||||
Vec::new()
|
||||
pub async fn get_nodes(&self) -> Vec<(String, NodeInfo)> {
|
||||
let nodes = self.nodes.read().await;
|
||||
nodes.iter().map(|(k, v)| (k.clone(), v.clone())).collect()
|
||||
}
|
||||
|
||||
pub fn get_my_ip(&self) -> String {
|
||||
pub async fn get_my_ip(&self) -> String {
|
||||
self.my_ip.clone()
|
||||
}
|
||||
|
||||
pub fn get_my_info(&self) -> NodeInfo {
|
||||
if let Ok(nodes) = self.nodes.read() {
|
||||
if let Some(found_info) = nodes.get(&self.my_ip) {
|
||||
return found_info.clone();
|
||||
}
|
||||
}
|
||||
NodeInfo::new_empty()
|
||||
pub async fn get_my_info(&self) -> NodeInfo {
|
||||
let nodes = self.nodes.read().await;
|
||||
nodes.get(&self.my_ip).cloned().unwrap_or(NodeInfo::new_empty())
|
||||
}
|
||||
|
||||
pub fn get_connected_ips(&self) -> Vec<String> {
|
||||
if let Ok(conns) = self.conns.read() {
|
||||
return conns.iter().map(|n| n.clone()).collect();
|
||||
}
|
||||
Vec::new()
|
||||
pub async fn get_connected_ips(&self) -> Vec<String> {
|
||||
let conns = self.conns.read().await;
|
||||
conns.iter().map(|n| n.clone()).collect()
|
||||
}
|
||||
|
||||
// returns a random node that does not have an active connection
|
||||
pub fn get_random_disconnected_ip(&self) -> Option<String> {
|
||||
pub async fn get_random_disconnected_ip(&self) -> Option<String> {
|
||||
use rand::{rngs::OsRng, RngCore};
|
||||
|
||||
let conn_ips = self.get_connected_ips();
|
||||
if let Ok(nodes) = self.nodes.read() {
|
||||
let conn_ips = self.get_connected_ips().await;
|
||||
let nodes = self.nodes.read().await;
|
||||
let skip = OsRng.next_u64().try_into().unwrap_or(0) % nodes.len();
|
||||
return nodes
|
||||
nodes
|
||||
.keys()
|
||||
.map(|ip| ip.to_string())
|
||||
.filter(|ip| ip != &self.my_ip && !conn_ips.contains(ip))
|
||||
.cycle()
|
||||
.nth(skip);
|
||||
}
|
||||
None
|
||||
.nth(skip)
|
||||
}
|
||||
|
||||
/// This returns true if the update should be further forwarded
|
||||
/// For example, we never forward our own updates that came back
|
||||
pub fn process_node_update(&self, (node_ip, node_info): (String, NodeInfo)) -> bool {
|
||||
pub async fn process_node_update(&self, (node_ip, node_info): (String, NodeInfo)) -> bool {
|
||||
let is_update_mine = node_ip.eq(&self.my_ip);
|
||||
let mut is_update_new = false;
|
||||
if let Ok(mut nodes) = self.nodes.write() {
|
||||
is_update_new = nodes
|
||||
let mut nodes = self.nodes.write().await;
|
||||
let is_update_new = nodes
|
||||
.get(&node_ip)
|
||||
.map(|curr_info| node_info.is_newer_than(&curr_info))
|
||||
.unwrap_or(true);
|
||||
if is_update_new {
|
||||
let _ = nodes.insert(node_ip.clone(), node_info.clone());
|
||||
}
|
||||
}
|
||||
if is_update_new {
|
||||
println!("Inserting: {}, {}", node_ip, node_info.to_json());
|
||||
let _ = nodes.insert(node_ip, node_info);
|
||||
}
|
||||
is_update_new && !is_update_mine
|
||||
}
|
||||
|
||||
pub fn remove_inactive_nodes(&self, max_age: u64) {
|
||||
if let Ok(mut nodes) = self.nodes.write() {
|
||||
pub async fn remove_inactive_nodes(&self, max_age: u64) {
|
||||
let mut nodes = self.nodes.write().await;
|
||||
// TODO: Check if it is possible to corrupt SGX system time
|
||||
let now = SystemTime::now();
|
||||
nodes.retain(|_, n| {
|
||||
@ -259,4 +249,3 @@ impl State {
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -14,7 +14,7 @@ pub async fn grpc_new_conn(
|
||||
ra_cfg: RaTlsConfig,
|
||||
tx: Sender<InternalNodeUpdate>,
|
||||
) {
|
||||
if Ipv4Addr::from_str(&node_ip).is_err() || node_ip == state.get_my_ip() {
|
||||
if Ipv4Addr::from_str(&node_ip).is_err() || node_ip == state.get_my_ip().await {
|
||||
println!("IPv4 address is invalid: {node_ip}");
|
||||
return;
|
||||
}
|
||||
@ -47,11 +47,11 @@ impl ConnManager {
|
||||
|
||||
async fn connect_to(&self, node_ip: String) {
|
||||
let state = self.state.clone();
|
||||
state.add_conn(&node_ip);
|
||||
state.add_conn(&node_ip).await;
|
||||
if let Err(e) = self.connect_to_int(node_ip.clone()).await {
|
||||
println!("Client connection for {node_ip} failed: {e:?}");
|
||||
}
|
||||
state.delete_conn(&node_ip);
|
||||
state.delete_conn(&node_ip).await;
|
||||
}
|
||||
|
||||
async fn connect_to_int(&self, remote_ip: String) -> Result<(), Box<dyn std::error::Error>> {
|
||||
@ -104,20 +104,21 @@ impl ConnManager {
|
||||
}
|
||||
});
|
||||
|
||||
let response = client.get_updates(rx_stream).await.map_err(|e| {
|
||||
let response = client.get_updates(rx_stream).await;
|
||||
if let Err(e) = response {
|
||||
println!("Error connecting to {remote_ip}: {e}");
|
||||
if e.to_string().contains("QuoteVerifyError") {
|
||||
self.state.increase_net_attacks();
|
||||
self.state.increase_net_attacks().await;
|
||||
}
|
||||
e
|
||||
})?;
|
||||
let mut resp_stream = response.into_inner();
|
||||
return Err(e.into());
|
||||
}
|
||||
let mut resp_stream = response.unwrap().into_inner();
|
||||
|
||||
// Immediately send our info as a network update
|
||||
let _ = self.tx.send((self.state.get_my_ip(), self.state.get_my_info()).into());
|
||||
let _ = self.tx.send((self.state.get_my_ip().await, self.state.get_my_info().await).into());
|
||||
while let Some(update) = resp_stream.message().await? {
|
||||
// Update the entire network in case the information is new
|
||||
if self.state.process_node_update(update.clone().into()) {
|
||||
if self.state.process_node_update(update.clone().into()).await {
|
||||
// If process update returns true, the update must be forwarded
|
||||
if self.tx.send((remote_ip.clone(), update).into()).is_err() {
|
||||
println!("Tokio broadcast receivers had an issue consuming the channel");
|
||||
@ -170,12 +171,13 @@ async fn query_keys(
|
||||
let uri = Uri::from_static("https://example.com");
|
||||
let mut client = UpdateClient::with_origin(client, uri);
|
||||
|
||||
let response = client.get_keys(tonic::Request::new(Empty {})).await.map_err(|e| {
|
||||
let response = client.get_keys(tonic::Request::new(Empty {})).await;
|
||||
if let Err(e) = response {
|
||||
println!("Error getting keys from {node_ip}: {e}");
|
||||
if e.to_string().contains("QuoteVerifyError") {
|
||||
state.increase_net_attacks();
|
||||
state.increase_net_attacks().await;
|
||||
}
|
||||
e
|
||||
})?;
|
||||
Ok(response.into_inner())
|
||||
return Err(e.into());
|
||||
}
|
||||
Ok(response.unwrap().into_inner())
|
||||
}
|
||||
|
@ -92,7 +92,7 @@ impl NodeServer {
|
||||
let conn = if let Err(e) = conn {
|
||||
println!("Error accepting TLS connection: {e}");
|
||||
if e.to_string().contains("HandshakeFailure") {
|
||||
state.increase_net_attacks();
|
||||
state.increase_net_attacks().await;
|
||||
}
|
||||
return;
|
||||
} else {
|
||||
@ -131,12 +131,12 @@ struct ConnInfo {
|
||||
|
||||
#[tonic::async_trait]
|
||||
impl Update for NodeServer {
|
||||
type GetUpdatesStream = Pin<Box<dyn Stream<Item = Result<NodeUpdate, Status>> + Send>>;
|
||||
|
||||
async fn get_keys(&self, _request: Request<Empty>) -> Result<Response<Keys>, Status> {
|
||||
Ok(Response::new(self.keys.clone()))
|
||||
}
|
||||
|
||||
type GetUpdatesStream = Pin<Box<dyn Stream<Item = Result<NodeUpdate, Status>> + Send>>;
|
||||
|
||||
async fn get_updates(
|
||||
&self,
|
||||
req: Request<Streaming<NodeUpdate>>,
|
||||
@ -148,13 +148,13 @@ impl Update for NodeServer {
|
||||
let mut rx = self.tx.subscribe();
|
||||
let mut inbound = req.into_inner();
|
||||
let state = self.state.clone();
|
||||
let my_ip = self.state.get_my_ip();
|
||||
let my_ip = self.state.get_my_ip().await;
|
||||
|
||||
let stream = async_stream::stream! {
|
||||
state.declare_myself_public();
|
||||
state.add_conn(&remote_ip);
|
||||
state.declare_myself_public().await;
|
||||
state.add_conn(&remote_ip).await;
|
||||
|
||||
let known_nodes: Vec<NodeUpdate> = state.get_nodes().into_iter().map(Into::into).collect();
|
||||
let known_nodes: Vec<NodeUpdate> = state.get_nodes().await.into_iter().map(Into::into).collect();
|
||||
for update in known_nodes {
|
||||
yield Ok(update);
|
||||
}
|
||||
@ -173,7 +173,7 @@ impl Update for NodeServer {
|
||||
println!("Node {remote_ip} is forwarding the update of {}", update.ip);
|
||||
}
|
||||
|
||||
if state.process_node_update(update.clone().into()) {
|
||||
if state.process_node_update(update.clone().into()).await {
|
||||
// if process update returns true, the update must be forwarded
|
||||
if tx.send((remote_ip.clone(), update).into()).is_err() {
|
||||
println!("Tokio broadcast receivers had an issue consuming the channel");
|
||||
@ -199,7 +199,7 @@ impl Update for NodeServer {
|
||||
}
|
||||
}
|
||||
}
|
||||
state.delete_conn(&remote_ip);
|
||||
state.delete_conn(&remote_ip).await;
|
||||
yield Err(error_status);
|
||||
};
|
||||
|
||||
|
@ -47,9 +47,10 @@ impl From<(String, datastore::NodeInfo)> for NodesResp {
|
||||
}
|
||||
|
||||
#[get("/nodes")]
|
||||
async fn get_nodes(ds: web::Data<Arc<State>>) -> HttpResponse {
|
||||
async fn get_nodes(state: web::Data<Arc<State>>) -> HttpResponse {
|
||||
let nodes = state.get_nodes().await;
|
||||
HttpResponse::Ok()
|
||||
.json(ds.get_nodes().into_iter().map(Into::<NodesResp>::into).collect::<Vec<NodesResp>>())
|
||||
.json(nodes.into_iter().map(Into::<NodesResp>::into).collect::<Vec<NodesResp>>())
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
@ -64,13 +65,13 @@ async fn mint(
|
||||
req: web::Json<MintReq>,
|
||||
) -> impl Responder {
|
||||
let recipient = req.into_inner().wallet;
|
||||
state.increase_mint_requests();
|
||||
state.increase_mint_requests().await;
|
||||
let result =
|
||||
web::block(move || sol_client.mint(&recipient).map_err(|e| e.to_string())).await.unwrap(); // TODO: check if this can get a BlockingError
|
||||
|
||||
match result {
|
||||
Ok(s) => {
|
||||
state.increase_mints();
|
||||
state.increase_mints().await;
|
||||
HttpResponse::Ok().body(format!(r#"{{" signature": "{s} "}}"#))
|
||||
}
|
||||
Err(e) => HttpResponse::InternalServerError().body(format!(r#"{{ "error": "{e}" }}"#)),
|
||||
@ -78,9 +79,9 @@ async fn mint(
|
||||
}
|
||||
|
||||
#[get("/metrics")]
|
||||
async fn metrics(ds: web::Data<Arc<State>>) -> HttpResponse {
|
||||
async fn metrics(state: web::Data<Arc<State>>) -> HttpResponse {
|
||||
let mut metrics = String::new();
|
||||
for (ip, node) in ds.get_nodes() {
|
||||
for (ip, node) in state.get_nodes().await {
|
||||
metrics.push_str(node.to_metrics(&ip).as_str());
|
||||
}
|
||||
HttpResponse::Ok().content_type("text/plain; version=0.0.4; charset=utf-8").body(metrics)
|
||||
|
10
src/main.rs
10
src/main.rs
@ -51,12 +51,12 @@ pub async fn heartbeat(
|
||||
loop {
|
||||
interval.tick().await;
|
||||
println!("Heartbeat...");
|
||||
state.remove_inactive_nodes(HEARTBEAT_INTERVAL * 3);
|
||||
let connected_ips = state.get_connected_ips();
|
||||
state.remove_inactive_nodes(HEARTBEAT_INTERVAL * 3).await;
|
||||
let connected_ips = state.get_connected_ips().await;
|
||||
println!("Connected nodes ({}): {:?}", connected_ips.len(), connected_ips);
|
||||
let _ = tx.send((state.get_my_ip(), state.get_my_info()).into());
|
||||
let _ = tx.send((state.get_my_ip().await, state.get_my_info().await).into());
|
||||
if connected_ips.len() < NUM_CONNECTIONS {
|
||||
if let Some(node_ip) = state.get_random_disconnected_ip() {
|
||||
if let Some(node_ip) = state.get_random_disconnected_ip().await {
|
||||
println!("Dialing random node {}", node_ip);
|
||||
tasks.spawn(grpc_new_conn(node_ip, state.clone(), ra_cfg.clone(), tx.clone()));
|
||||
}
|
||||
@ -77,7 +77,7 @@ async fn init_token(state: Arc<State>, ra_cfg: RaTlsConfig) -> SolClient {
|
||||
}
|
||||
Err(SealError::Attack(e)) => {
|
||||
println!("The sealed keys are corrupted: {}", e);
|
||||
state.increase_disk_attacks();
|
||||
state.increase_disk_attacks().await;
|
||||
}
|
||||
Err(e) => {
|
||||
println!("Could not read sealed keys: {e}");
|
||||
|
Loading…
Reference in New Issue
Block a user