hacker-challenge-sgx-general/src/datastore.rs

255 lines
8.6 KiB
Rust

use crate::persistence::{SealError, SealedFile};
use serde::{Deserialize, Serialize};
use serde_with::{serde_as, TimestampSeconds};
use std::collections::{HashMap, HashSet};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tokio::sync::RwLock;
/// Identifier of a node, used as the key of the node table and of the
/// connected-peer set in `State`.
/// NOTE(review): despite the name, nothing in this file validates that the
/// string is an IP address.
type IP = String;
/// Path passed to `NodeInfo::read` / `NodeInfo::write` (via `NodeInfo::load`
/// and `NodeInfo::save`) to persist this node's own record across restarts.
const LOCAL_INFO_FILE: &str = "/host/main/node_info";
/// Status record of one node, as kept in the local node table and as
/// received in node updates (`State::process_node_update`).
///
/// Field order matters: the derived `PartialOrd`/`Ord` compare fields
/// lexicographically in declaration order, and it is also the order in
/// which fields are serialized.
#[serde_as]
#[derive(Clone, Serialize, Deserialize, PartialEq, Eq, Ord, PartialOrd)]
pub struct NodeInfo {
/// Set to the current time by `NodeInfo::new_empty`; left unchanged when
/// `NodeInfo::load` restores the record, so it survives restarts.
/// Serialized as whole seconds since the UNIX epoch (`TimestampSeconds`).
#[serde_as(as = "TimestampSeconds")]
pub started_at: SystemTime,
/// Liveness timestamp. `is_newer_than` compares it to decide which version
/// of a record wins, and `State::remove_inactive_nodes` drops records whose
/// keepalive is older than the allowed age.
/// Serialized as whole seconds since the UNIX epoch (`TimestampSeconds`).
#[serde_as(as = "TimestampSeconds")]
pub keepalive: SystemTime,
/// Incremented by `State::increase_mint_requests`.
pub mint_requests: u64,
/// Incremented by `State::increase_mints`.
pub mints: u64,
/// Incremented/decremented (never below zero) by `State::add_conn` /
/// `State::delete_conn`; reset to 0 when `NodeInfo::load` restores the record.
pub mratls_conns: u64,
/// Incremented by `State::increase_net_attacks`.
pub net_attacks: u64,
/// Set to `true` by `State::declare_myself_public`; nothing in this file
/// sets it back to `false`.
pub public: bool,
/// Incremented each time `NodeInfo::load` successfully restores the record.
pub restarts: u64,
/// Incremented by `State::increase_disk_attacks`; `NodeInfo::load` starts a
/// fresh record at 1 when reading the local file reports `SealError::Attack`.
pub disk_attacks: u64,
}
impl NodeInfo {
    /// Creates a record for a node with no history.
    ///
    /// `started_at` and `keepalive` are set to the same current instant, every
    /// counter is zero and `public` is `false`.
    pub fn new_empty() -> Self {
        // Read the clock once so both timestamps are guaranteed to be equal.
        let now = SystemTime::now();
        NodeInfo {
            started_at: now,
            keepalive: now,
            mint_requests: 0,
            mints: 0,
            mratls_conns: 0,
            net_attacks: 0,
            public: false,
            restarts: 0,
            disk_attacks: 0,
        }
    }

    /// Returns `true` if `self` has a strictly later `keepalive` than
    /// `older_self`. Equal timestamps are *not* considered newer.
    pub fn is_newer_than(&self, older_self: &Self) -> bool {
        self.keepalive > older_self.keepalive
    }

    /// Serializes the record to a compact JSON string.
    ///
    /// # Panics
    /// Panics if `serde_json` reports a serialization error.
    /// NOTE(review): a previous comment claimed this can only fail "if time
    /// goes backwards" (a timestamp before the UNIX epoch) — confirm against
    /// `serde_with::TimestampSeconds`.
    pub fn to_json(&self) -> String {
        serde_json::to_string(self).expect("NodeInfo must serialize to JSON")
    }

    /// Renders the record in the Prometheus text exposition format.
    ///
    /// Emits one `name{ip="…", public="…"} value` line per metric.
    /// Timestamps are written as whole seconds since the UNIX epoch, and a
    /// timestamp before the epoch is written as `0`. The `ip` label value is
    /// escaped, so an identifier containing `\`, `"` or a newline cannot break
    /// or inject exposition lines.
    pub fn to_metrics(&self, ip: &str) -> String {
        use std::fmt::Write as _;

        let secs_since_epoch =
            |t: SystemTime| t.duration_since(UNIX_EPOCH).unwrap_or(Duration::ZERO).as_secs();
        let labels = format!(
            "{{ip=\"{}\", public=\"{}\"}}",
            Self::escape_label_value(ip),
            self.public
        );
        let samples: [(&str, u64); 8] = [
            ("started_at", secs_since_epoch(self.started_at)),
            ("keepalive", secs_since_epoch(self.keepalive)),
            ("mint_requests", self.mint_requests),
            ("mints", self.mints),
            ("mratls_conns", self.mratls_conns),
            ("net_attacks", self.net_attacks),
            ("restarts", self.restarts),
            ("disk_attacks", self.disk_attacks),
        ];

        let mut res = String::new();
        for &(name, value) in samples.iter() {
            // Writing into a `String` cannot fail.
            let _ = writeln!(res, "{}{} {}", name, labels, value);
        }
        res
    }

    /// Escapes a Prometheus label value: backslash, double quote and line feed
    /// must be written as `\\`, `\"` and `\n` in the text exposition format.
    fn escape_label_value(value: &str) -> String {
        let mut out = String::with_capacity(value.len());
        for c in value.chars() {
            match c {
                '\\' => out.push_str("\\\\"),
                '"' => out.push_str("\\\""),
                '\n' => out.push_str("\\n"),
                other => out.push(other),
            }
        }
        out
    }

    /// Restores this node's own record from `LOCAL_INFO_FILE`.
    ///
    /// - On success, `mratls_conns` is reset to 0 (no connections survive a
    ///   restart) and `restarts` is incremented.
    /// - If reading reports `SealError::Attack`, the error is logged and a
    ///   fresh record is returned with `disk_attacks` set to 1.
    /// - Any other error yields a fresh record from [`NodeInfo::new_empty`].
    pub fn load() -> Self {
        match Self::read(LOCAL_INFO_FILE) {
            Ok(mut info) => {
                info.mratls_conns = 0;
                info.restarts += 1;
                info
            }
            Err(SealError::Attack(e)) => {
                println!("The local node file is corrupted: {}", e);
                let mut info = Self::new_empty();
                info.disk_attacks += 1; // add very first disk attack
                info
            }
            Err(_) => Self::new_empty(),
        }
    }

    /// Persists the record to `LOCAL_INFO_FILE`.
    ///
    /// Best effort: a write error is logged and otherwise ignored.
    pub fn save(&self) {
        if let Err(e) = self.write(LOCAL_INFO_FILE) {
            println!("Could not save node info: {}", e);
        }
    }
}
/// Node-wide state designed to be shared across the node's code
/// (presumably behind an `Arc` — the wrapper is not visible in this file).
///
/// Holds this node's own identifier, the latest known record of every node
/// (our own included) and the set of peers we currently hold an mRA-TLS
/// connection with. Both collections sit behind `tokio::sync::RwLock`, which
/// is why every accessor except `new` is `async`.
pub struct State {
/// This node's own identifier; the key of our own record in `nodes`.
my_ip: String,
/// Latest known record for every node, keyed by node identifier.
nodes: RwLock<HashMap<IP, NodeInfo>>,
/// Identifiers of peers with an active mRA-TLS connection.
conns: RwLock<HashSet<IP>>,
}
impl State {
    /// Creates the shared state for the node identified by `my_ip`.
    ///
    /// Our own record is restored with [`NodeInfo::load`] and becomes the only
    /// entry of the node table; the connected-peer set starts empty.
    pub fn new(my_ip: String) -> Self {
        let mut nodes = HashMap::new();
        nodes.insert(my_ip.clone(), NodeInfo::load());
        Self { my_ip, nodes: RwLock::new(nodes), conns: RwLock::new(HashSet::new()) }
    }

    /// Registers a new mRA-TLS connection with `ip`.
    ///
    /// Bumps our own `mratls_conns` counter and adds `ip` to the connected set.
    pub async fn add_conn(&self, ip: &str) {
        self.increase_mratls_conns().await;
        self.add_mratls_conn(ip).await;
    }

    /// Unregisters the mRA-TLS connection with `ip`.
    ///
    /// Lowers our own `mratls_conns` counter (never below zero) and removes
    /// `ip` from the connected set.
    pub async fn delete_conn(&self, ip: &str) {
        self.decrease_mratls_conns().await;
        self.delete_mratls_conn(ip).await;
    }

    /// Adds `ip` to the set of connected peers.
    async fn add_mratls_conn(&self, ip: &str) {
        self.conns.write().await.insert(ip.to_string());
    }

    /// Removes `ip` from the set of connected peers (no-op if absent).
    async fn delete_mratls_conn(&self, ip: &str) {
        self.conns.write().await.remove(ip);
    }

    /// Mutates our own record in place while holding the node-table write lock.
    ///
    /// Does nothing if our own record is missing from the table.
    async fn update_my_info<F>(&self, update: F)
    where
        F: FnOnce(&mut NodeInfo),
    {
        let mut nodes = self.nodes.write().await;
        if let Some(my_info) = nodes.get_mut(&self.my_ip) {
            update(my_info);
        }
    }

    /// Increments our own `mint_requests` counter.
    pub async fn increase_mint_requests(&self) {
        self.update_my_info(|info| info.mint_requests += 1).await;
    }

    /// Increments our own `mints` counter.
    pub async fn increase_mints(&self) {
        self.update_my_info(|info| info.mints += 1).await;
    }

    /// Increments our own `mratls_conns` counter.
    pub async fn increase_mratls_conns(&self) {
        self.update_my_info(|info| info.mratls_conns += 1).await;
    }

    /// Decrements our own `mratls_conns` counter, saturating at zero.
    pub async fn decrease_mratls_conns(&self) {
        self.update_my_info(|info| {
            info.mratls_conns = info.mratls_conns.saturating_sub(1);
        })
        .await;
    }

    /// Increments our own `disk_attacks` counter.
    pub async fn increase_disk_attacks(&self) {
        self.update_my_info(|info| info.disk_attacks += 1).await;
    }

    /// Increments our own `net_attacks` counter.
    pub async fn increase_net_attacks(&self) {
        self.update_my_info(|info| info.net_attacks += 1).await;
    }

    /// Marks our own record as `public`.
    pub async fn declare_myself_public(&self) {
        self.update_my_info(|info| info.public = true).await;
    }

    /// Returns a snapshot of every known `(identifier, record)` pair, in
    /// unspecified (hash-map) order.
    pub async fn get_nodes(&self) -> Vec<(String, NodeInfo)> {
        let nodes = self.nodes.read().await;
        nodes.iter().map(|(ip, info)| (ip.clone(), info.clone())).collect()
    }

    /// Returns this node's own identifier.
    pub async fn get_my_ip(&self) -> String {
        self.my_ip.clone()
    }

    /// Returns a copy of our own record, or a fresh [`NodeInfo::new_empty`]
    /// record if ours is missing from the table.
    pub async fn get_my_info(&self) -> NodeInfo {
        let nodes = self.nodes.read().await;
        // `unwrap_or_else` so the default (which reads the clock) is only
        // built when it is actually needed.
        nodes.get(&self.my_ip).cloned().unwrap_or_else(NodeInfo::new_empty)
    }

    /// Returns a snapshot of the identifiers we currently hold an mRA-TLS
    /// connection with, in unspecified order.
    pub async fn get_connected_ips(&self) -> Vec<String> {
        let conns = self.conns.read().await;
        conns.iter().cloned().collect()
    }

    /// Returns a uniformly chosen known node that is neither us nor currently
    /// connected, or `None` if there is no such node.
    pub async fn get_random_disconnected_ip(&self) -> Option<String> {
        use rand::{rngs::OsRng, RngCore};

        // Snapshot the connected set first and release its lock before taking
        // the node-table lock, so the two locks are never held together.
        let connected: HashSet<String> = self.conns.read().await.iter().cloned().collect();
        let mut candidates: Vec<String> = {
            let nodes = self.nodes.read().await;
            nodes
                .keys()
                .filter(|ip| **ip != self.my_ip && !connected.contains(*ip))
                .cloned()
                .collect()
        };
        if candidates.is_empty() {
            return None;
        }
        // Index into the filtered candidates themselves, so every candidate is
        // (up to modulo bias) equally likely.
        let index = (OsRng.next_u64() % candidates.len() as u64) as usize;
        Some(candidates.swap_remove(index))
    }

    /// Merges a node record received from elsewhere into the node table.
    ///
    /// The record is stored only if no record for `node_ip` exists yet or the
    /// incoming one is strictly newer (see [`NodeInfo::is_newer_than`]).
    /// Returns `true` if the update should be forwarded further: it was new
    /// and it is not about this node (we never forward our own updates that
    /// came back).
    /// NOTE(review): a newer record for our own identifier is still stored,
    /// replacing our local record — confirm this is intended.
    pub async fn process_node_update(&self, (node_ip, node_info): (String, NodeInfo)) -> bool {
        let is_update_mine = node_ip == self.my_ip;
        let mut nodes = self.nodes.write().await;
        let is_update_new = nodes
            .get(&node_ip)
            .map_or(true, |curr_info| node_info.is_newer_than(curr_info));
        if is_update_new {
            println!("Inserting: {}, {}", node_ip, node_info.to_json());
            nodes.insert(node_ip, node_info);
        }
        is_update_new && !is_update_mine
    }

    /// Drops every record, other than our own, whose `keepalive` is more than
    /// `max_age` seconds in the past. A keepalive in the future counts as age 0.
    pub async fn remove_inactive_nodes(&self, max_age: u64) {
        let mut nodes = self.nodes.write().await;
        // TODO: Check if it is possible to corrupt SGX system time
        let now = SystemTime::now();
        nodes.retain(|ip, info| {
            // Never evict ourselves: our record is restored with the keepalive
            // that was on disk, and losing it would turn every `increase_*`
            // into a silent no-op.
            if *ip == self.my_ip {
                return true;
            }
            let age = now.duration_since(info.keepalive).unwrap_or(Duration::ZERO).as_secs();
            age <= max_age
        });
    }
}