From 2a87efacc13ac82ac4d9ea9145e0567252952d9e Mon Sep 17 00:00:00 2001 From: Valentyn Faychuk Date: Mon, 2 Dec 2024 03:39:27 +0200 Subject: [PATCH] metrics and refactoring Signed-off-by: Valentyn Faychuk --- Cargo.lock | 40 +++++++++++++++-- Cargo.toml | 4 +- src/datastore.rs | 52 ++++++++--------------- src/grpc/client.rs | 69 ++++++++++++++---------------- src/grpc/mod.rs | 8 ++-- src/grpc/server.rs | 98 ++++++++++++++++++++---------------------- src/http_server.rs | 37 +++++++++------- src/main.rs | 104 ++++++++++++++++++++++++++++----------------- src/persistence.rs | 53 +++++++++++++---------- src/solana.rs | 54 ++++++++++++++++++----- 10 files changed, 298 insertions(+), 221 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d4fcbc1..6c1e756 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1490,6 +1490,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4" dependencies = [ "powerfmt", + "serde", ] [[package]] @@ -1525,7 +1526,7 @@ dependencies = [ [[package]] name = "detee-sgx" version = "0.1.0" -source = "git+ssh://git@gitea.detee.cloud/SGX/detee-sgx#a47753a8e07ef533cca5df41bea4893c9eeb133e" +source = "git+ssh://git@gitea.detee.cloud/SGX/detee-sgx?branch=hacker-challenge#2f032b5c5448be40feda466bb457821ef814f959" dependencies = [ "aes-gcm", "base64 0.22.1", @@ -2042,6 +2043,7 @@ dependencies = [ "rustls 0.23.14", "serde", "serde_json", + "serde_with 3.11.0", "solana-client", "solana-program", "solana-sdk", @@ -2418,6 +2420,7 @@ checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" dependencies = [ "autocfg", "hashbrown 0.12.3", + "serde", ] [[package]] @@ -2428,6 +2431,7 @@ checksum = "707907fe3c25f5424cce2cb7e1cbcafee6bdbe735ca90ef77c29e84591e5b9da" dependencies = [ "equivalent", "hashbrown 0.15.0", + "serde", ] [[package]] @@ -3893,7 +3897,25 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "07ff71d2c147a7b57362cead5e22f772cd52f6ab31cfcd9edcd7f6aeb2a0afbe" dependencies = [ "serde", - "serde_with_macros", + "serde_with_macros 2.3.3", +] + +[[package]] +name = "serde_with" +version = "3.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e28bdad6db2b8340e449f7108f020b3b092e8583a9e3fb82713e1d4e71fe817" +dependencies = [ + "base64 0.22.1", + "chrono", + "hex", + "indexmap 1.9.3", + "indexmap 2.6.0", + "serde", + "serde_derive", + "serde_json", + "serde_with_macros 3.11.0", + "time", ] [[package]] @@ -3908,6 +3930,18 @@ dependencies = [ "syn 2.0.79", ] +[[package]] +name = "serde_with_macros" +version = "3.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d846214a9854ef724f3da161b426242d8de7c1fc7de2f89bb1efcb154dca79d" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "syn 2.0.79", +] + [[package]] name = "sha1" version = "0.10.6" @@ -4510,7 +4544,7 @@ dependencies = [ "serde_bytes", "serde_derive", "serde_json", - "serde_with", + "serde_with 2.3.3", "sha2 0.10.8", "sha3 0.10.8", "siphasher", diff --git a/Cargo.toml b/Cargo.toml index 42d2216..2d60b3e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,7 @@ prost = "0.13" prost-types = "0.13" rand = "0.8" serde = { version = "1.0", features = ["derive"] } +serde_with = { version = "3.11", features = ["macros", "base64"] } serde_json = "1.0" solana-client = "2.0" solana-program = "2.0" @@ -36,8 +37,7 @@ hyper-util = "0.1.7" hyper-rustls = { version = "0.27", features = ["http2"] } base64 = "0.22" lazy_static = "1.5" -# TODO: create a feature for testing, make occlum feature optional and added only if not compiling for testing -detee-sgx = { git = "ssh://git@gitea.detee.cloud/SGX/detee-sgx", features = ["tonic", "occlum", "sealing"] } +detee-sgx = { git = "ssh://git@gitea.detee.cloud/SGX/detee-sgx", branch = "hacker-challenge", features = ["tonic", "occlum", "sealing"] } env_logger = "0.11" diff --git a/src/datastore.rs b/src/datastore.rs index b9218ff..2e2edbd 100644 --- a/src/datastore.rs +++ b/src/datastore.rs @@ -1,8 +1,6 @@ -#![allow(dead_code)] -use crate::solana::Client as SolClient; +use crate::persistence::Logfile; use dashmap::{DashMap, DashSet}; use std::time::{Duration, SystemTime}; -use crate::persistence::Logfile; type IP = String; pub const LOCALHOST: &str = "localhost"; @@ -14,7 +12,7 @@ pub struct NodeInfo { pub mint_requests: u64, pub mints: u64, pub mratls_conns: u64, - pub quote_attacks: u64, + pub net_attacks: u64, pub public: bool, pub restarts: u64, pub disk_attacks: u64, @@ -28,7 +26,8 @@ impl NodeInfo { } } if self.mint_requests > other_node.mint_requests - || self.quote_attacks > other_node.quote_attacks + || self.net_attacks > other_node.net_attacks + || self.disk_attacks > other_node.disk_attacks || self.mratls_conns > other_node.mratls_conns || self.mints > other_node.mints || (self.public && !other_node.public) @@ -45,19 +44,17 @@ impl NodeInfo { } } -/// Keypair must already be known when creating a Store -/// This means the first node of the network creates the key -/// Second node will grab the key from the first node -pub struct Store { - sol_client: SolClient, +/// Multithreaded state, designed to be +/// shared everywhere in the code +pub struct State { nodes: DashMap, conns: DashSet, } -impl Store { - pub fn init(sol_client: SolClient) -> Self { - let store = Self { sol_client, nodes: DashMap::new(), conns: DashSet::new() }; - store.nodes.insert( +impl State { + pub fn new() -> Self { + let state = Self { nodes: DashMap::new(), conns: DashSet::new() }; + state.nodes.insert( LOCALHOST.to_string(), NodeInfo { started_at: SystemTime::now(), @@ -65,25 +62,13 @@ impl Store { mint_requests: 0, mints: 0, mratls_conns: 0, - quote_attacks: 0, + net_attacks: 0, public: false, restarts: 0, disk_attacks: 0, }, ); - store - } - - pub fn get_token_address(&self) -> String { - self.sol_client.token_address() - } - - pub fn get_pubkey_base58(&self) -> String { - self.sol_client.wallet_address() - } - - pub fn get_keypair_bytes(&self) -> Vec { - self.sol_client.get_keypair_bytes() + state } pub fn add_conn(&self, ip: &str) { @@ -122,12 +107,11 @@ impl Store { } } - pub fn mint(&self, recipient: &str) -> Result> { - use std::str::FromStr; - let recipient = solana_sdk::pubkey::Pubkey::from_str(recipient)?; - let sig = self.sol_client.mint(&recipient)?; - self.increase_mints(); - Ok(sig) + pub fn increase_net_attacks(&self) { + if let Some(mut localhost_info) = self.nodes.get_mut(LOCALHOST) { + localhost_info.net_attacks += 1; + localhost_info.log(localhost_info.key()); + } } pub fn get_localhost(&self) -> NodeInfo { diff --git a/src/grpc/client.rs b/src/grpc/client.rs index 5ac5a50..fc700b1 100644 --- a/src/grpc/client.rs +++ b/src/grpc/client.rs @@ -1,10 +1,9 @@ -#![allow(dead_code)] -use super::challenge::NodeUpdate; +use super::challenge::{Keys, NodeUpdate}; use crate::{ - datastore::{Store, LOCALHOST}, + datastore::{State, LOCALHOST}, grpc::challenge::{update_client::UpdateClient, Empty}, }; -use solana_sdk::{pubkey::Pubkey, signature::keypair::Keypair}; +use detee_sgx::RaTlsConfig; use std::{str::FromStr, sync::Arc}; use tokio::{ sync::broadcast::Sender, @@ -14,13 +13,14 @@ use tokio_stream::{wrappers::BroadcastStream, StreamExt}; #[derive(Clone)] pub struct ConnManager { - ds: Arc, + state: Arc, tx: Sender, + ratls_config: RaTlsConfig, } impl ConnManager { - pub fn init(ds: Arc, tx: Sender) -> Self { - Self { ds, tx } + pub fn init(state: Arc, ratls_config: RaTlsConfig, tx: Sender) -> Self { + Self { state, ratls_config, tx } } pub async fn start_with_node(self, node_ip: String) { @@ -29,7 +29,7 @@ impl ConnManager { pub async fn start(self) { loop { - if let Some(node) = self.ds.get_random_node() { + if let Some(node) = self.state.get_random_node() { if node != LOCALHOST { self.connect_wrapper(node.clone()).await; } @@ -39,7 +39,7 @@ impl ConnManager { } async fn connect_wrapper(&self, node_ip: String) { - let ds = self.ds.clone(); + let ds = self.state.clone(); ds.add_conn(&node_ip); if let Err(e) = self.connect(node_ip.clone()).await { println!("Client connection for {node_ip} failed: {e:?}"); @@ -48,20 +48,14 @@ impl ConnManager { } async fn connect(&self, node_ip: String) -> Result<(), Box> { - use detee_sgx::{prelude::*, RaTlsConfigBuilder}; + use detee_sgx::RaTlsConfigBuilder; use hyper::Uri; use hyper_util::{client::legacy::connect::HttpConnector, rt::TokioExecutor}; use tokio_rustls::rustls::ClientConfig; println!("Connecting to {node_ip}..."); - let mrsigner_hex = "83E8A0C3ED045D9747ADE06C3BFC70FCA661A4A65FF79A800223621162A88B76"; - let mrsigner = - crate::sgx::mrsigner_from_hex(mrsigner_hex).expect("mrsigner decoding failed"); - let config = RaTlsConfig::new() - .allow_instance_measurement(InstanceMeasurement::new().with_mrsigners(vec![mrsigner])); - - let tls = ClientConfig::from_ratls_config(config) + let tls = ClientConfig::from_ratls_config(self.ratls_config.clone()) .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, format!("{}", e)))?; let mut http = HttpConnector::new(); @@ -93,10 +87,14 @@ impl ConnManager { let rx = self.tx.subscribe(); let rx_stream = BroadcastStream::new(rx).filter_map(|n| n.ok()); - let response = client.get_updates(rx_stream).await?; + let response = client.get_updates(rx_stream).await.map_err(|e| { + println!("Error connecting to {node_ip}: {e}"); + self.state.increase_net_attacks(); + e + })?; let mut resp_stream = response.into_inner(); - let _ = self.tx.send((LOCALHOST.to_string(), self.ds.get_localhost()).into()); + let _ = self.tx.send((LOCALHOST.to_string(), self.state.get_localhost()).into()); while let Some(mut update) = resp_stream.message().await? { // "localhost" IPs need to be changed to the real IP of the counterpart @@ -108,7 +106,7 @@ impl ConnManager { } // update the entire network in case the information is new - if self.ds.process_node_update(update.clone().into()) + if self.state.process_node_update(update.clone().into()) && self.tx.send(update.clone()).is_err() { println!("tokio broadcast receivers had an issue consuming the channel"); @@ -119,20 +117,19 @@ impl ConnManager { } } -pub async fn key_grabber(node_ip: String) -> Result<(Keypair, Pubkey), Box> { - use detee_sgx::{prelude::*, RaTlsConfigBuilder}; +pub async fn key_grabber( + state: Arc, + node_ip: String, + ratls_config: RaTlsConfig, +) -> Result> { + use detee_sgx::RaTlsConfigBuilder; use hyper::Uri; use hyper_util::{client::legacy::connect::HttpConnector, rt::TokioExecutor}; use tokio_rustls::rustls::ClientConfig; println!("Getting key from {node_ip}..."); - let mrsigner_hex = "83E8A0C3ED045D9747ADE06C3BFC70FCA661A4A65FF79A800223621162A88B76"; - let mrsigner = crate::sgx::mrsigner_from_hex(mrsigner_hex).expect("mrsigner decoding failed"); - let config = RaTlsConfig::new() - .allow_instance_measurement(InstanceMeasurement::new().with_mrsigners(vec![mrsigner])); - - let tls = ClientConfig::from_ratls_config(config) + let tls = ClientConfig::from_ratls_config(ratls_config) .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, format!("{}", e)))?; let mut http = HttpConnector::new(); @@ -160,14 +157,10 @@ pub async fn key_grabber(node_ip: String) -> Result<(Keypair, Pubkey), Box k, - Err(_) => return Err("Could not parse keypair.".into()), - }; - let token_address = Pubkey::from_str(&response.token_address) - .map_err(|_| "Could not parse wallet address.".to_string())?; - Ok((keypair, token_address)) + let response = client.get_keys(tonic::Request::new(Empty {})).await.map_err(|e| { + println!("Error getting keys from {node_ip}: {e}"); + state.increase_net_attacks(); + e + })?; + Ok(response.into_inner()) } diff --git a/src/grpc/mod.rs b/src/grpc/mod.rs index 336d24b..f29ca19 100644 --- a/src/grpc/mod.rs +++ b/src/grpc/mod.rs @@ -16,10 +16,10 @@ impl From<(String, NodeInfo)> for NodeUpdate { mint_requests: info.mint_requests, mints: info.mints, mratls_conns: info.mratls_conns, - quote_attacks: info.quote_attacks, + quote_attacks: info.net_attacks, public: info.public, restarts: info.restarts, - disk_attacks: info.disk_attacks + disk_attacks: info.disk_attacks, } } } @@ -47,10 +47,10 @@ impl From for (String, NodeInfo) { mint_requests: val.mint_requests, mints: val.mints, mratls_conns: val.mratls_conns, - quote_attacks: val.quote_attacks, + net_attacks: val.quote_attacks, public: val.public, restarts: val.restarts, - disk_attacks: val.disk_attacks + disk_attacks: val.disk_attacks, }; (ip, self_info) } diff --git a/src/grpc/server.rs b/src/grpc/server.rs index 1ed9f58..3625c8f 100644 --- a/src/grpc/server.rs +++ b/src/grpc/server.rs @@ -1,23 +1,30 @@ -#![allow(dead_code)] - use super::challenge::{update_server::UpdateServer, Empty, Keys, NodeUpdate}; -use crate::{datastore::Store, grpc::challenge::update_server::Update}; +use crate::{datastore::State, grpc::challenge::update_server::Update}; +use detee_sgx::RaTlsConfig; use std::{pin::Pin, sync::Arc}; use tokio::sync::broadcast::Sender; use tokio_stream::{Stream, StreamExt}; use tonic::{Request, Response, Status, Streaming}; pub struct MyServer { - ds: Arc, + state: Arc, tx: Sender, + ratls_config: RaTlsConfig, + keys: Keys, // For sending secret keys to new nodes ;) } impl MyServer { - pub fn init(ds: Arc, tx: Sender) -> Self { - Self { ds, tx } + pub fn init( + state: Arc, + keys: Keys, + ratls_config: RaTlsConfig, + tx: Sender, + ) -> Self { + Self { state, tx, keys, ratls_config } } pub async fn start(self) { + use detee_sgx::RaTlsConfigBuilder; use hyper::server::conn::http2::Builder; use hyper_util::{ rt::{TokioExecutor, TokioIo}, @@ -29,22 +36,12 @@ impl MyServer { use tonic::{body::boxed, service::Routes}; use tower::{ServiceBuilder, ServiceExt}; - use detee_sgx::{prelude::*, RaTlsConfigBuilder}; - - // TODO: ratls config should be global // TODO: error handling, shouldn't have expects - let mrsigner_hex = "83E8A0C3ED045D9747ADE06C3BFC70FCA661A4A65FF79A800223621162A88B76"; - let mrsigner = - crate::sgx::mrsigner_from_hex(mrsigner_hex).expect("mrsigner decoding failed"); - let config = RaTlsConfig::new() - .allow_instance_measurement(InstanceMeasurement::new().with_mrsigners(vec![mrsigner])); - - let mut tls = ServerConfig::from_ratls_config(config) - .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, format!("{}", e))) - .expect("failed to create server config"); + let mut tls = ServerConfig::from_ratls_config(self.ratls_config.clone()).unwrap(); tls.alpn_protocols = vec![b"h2".to_vec()]; + let state = self.state.clone(); let svc = Routes::new(UpdateServer::new(self)).prepare(); let http = Builder::new(TokioExecutor::new()); @@ -53,10 +50,11 @@ impl MyServer { let tls_acceptor = TlsAcceptor::from(Arc::new(tls)); loop { - let (conn, addr) = match listener.accept().await { + let (conn, _addr) = match listener.accept().await { Ok(incoming) => incoming, Err(e) => { - eprintln!("Error accepting connection: {}", e); + println!("Error accepting connection: {}", e); + state.increase_net_attacks(); continue; } }; @@ -65,6 +63,7 @@ impl MyServer { let tls_acceptor = tls_acceptor.clone(); let svc = svc.clone(); + let state = state.clone(); tokio::spawn(async move { let mut certificates = Vec::new(); @@ -76,30 +75,29 @@ impl MyServer { } } }) + .await; + + let conn = if let Err(e) = conn { + println!("Error accepting TLS connection: {}", e); + state.increase_net_attacks(); + return; + } else { + conn.unwrap() + }; + + let svc = ServiceBuilder::new().service(svc); + + if let Err(e) = http + .serve_connection( + TokioIo::new(conn), + TowerToHyperService::new( + svc.map_request(|req: hyper::Request<_>| req.map(boxed)), + ), + ) .await - .unwrap(); - - #[derive(Debug)] - pub struct ConnInfo { - pub addr: std::net::SocketAddr, - pub certificates: Vec>, + { + println!("Error serving connection: {}", e); } - - let extension_layer = - tower_http::add_extension::AddExtensionLayer::new(Arc::new(ConnInfo { - addr, - certificates, - })); - let svc = ServiceBuilder::new().layer(extension_layer).service(svc); - - http.serve_connection( - TokioIo::new(conn), - TowerToHyperService::new( - svc.map_request(|req: hyper::Request<_>| req.map(boxed)), - ), - ) - .await - .unwrap(); }); } } @@ -109,16 +107,20 @@ impl MyServer { impl Update for MyServer { type GetUpdatesStream = Pin> + Send>>; + async fn get_keys(&self, _request: Request) -> Result, Status> { + Ok(Response::new(self.keys.clone())) + } + async fn get_updates( &self, req: Request>, ) -> Result, Status> { - self.ds.increase_mratls_conns(); + self.state.increase_mratls_conns(); let remote_ip = req.remote_addr().unwrap().ip().to_string(); let tx = self.tx.clone(); let mut rx = self.tx.subscribe(); let mut inbound = req.into_inner(); - let ds = self.ds.clone(); + let ds = self.state.clone(); let stream = async_stream::stream! { let full_update_list: Vec = ds.get_node_list().into_iter().map(Into::::into).collect(); @@ -161,12 +163,4 @@ impl Update for MyServer { Ok(Response::new(Box::pin(stream) as Self::GetUpdatesStream)) } - - async fn get_keys(&self, _request: Request) -> Result, Status> { - let reply = Keys { - keypair: self.ds.get_keypair_bytes(), - token_address: self.ds.get_token_address(), - }; - Ok(Response::new(reply)) - } } diff --git a/src/http_server.rs b/src/http_server.rs index e1fa2e9..97dbf85 100644 --- a/src/http_server.rs +++ b/src/http_server.rs @@ -1,5 +1,4 @@ -#![allow(dead_code)] -use crate::{datastore, datastore::Store}; +use crate::{datastore, datastore::State, solana::SolClient}; use actix_web::{get, post, web, App, HttpResponse, HttpServer, Responder}; use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; @@ -8,11 +7,11 @@ use std::sync::Arc; const HOMEPAGE: &str = include_str!("HOMEPAGE.md"); #[get("/")] -async fn homepage(ds: web::Data>) -> impl Responder { +async fn homepage(sol_client: web::Data>) -> impl Responder { let text = HOMEPAGE .to_string() - .replace("TOKEN_ADDRESS", &ds.get_token_address()) - .replace("MINT_AUTHORITY", &ds.get_pubkey_base58()); + .replace("TOKEN_ADDRESS", &sol_client.get_token_address()) + .replace("MINT_AUTHORITY", &sol_client.get_wallet_pubkey()); HttpResponse::Ok().body(text) } @@ -40,7 +39,7 @@ impl From<(String, datastore::NodeInfo)> for NodesResp { last_keepalive, mints: node_info.mints, total_ratls_conns: node_info.mratls_conns, - ratls_attacks: node_info.quote_attacks, + ratls_attacks: node_info.net_attacks, public: node_info.public, mint_requests: node_info.mint_requests, } @@ -48,7 +47,7 @@ impl From<(String, datastore::NodeInfo)> for NodesResp { } #[get("/nodes")] -async fn get_nodes(ds: web::Data>) -> HttpResponse { +async fn get_nodes(ds: web::Data>) -> HttpResponse { HttpResponse::Ok().json( ds.get_node_list().into_iter().map(Into::::into).collect::>(), ) @@ -60,22 +59,30 @@ struct MintReq { } #[post("/mint")] -async fn mint(ds: web::Data>, req: web::Json) -> impl Responder { - ds.increase_mint_requests(); +async fn mint( + state: web::Data>, + sol_client: web::Data>, + req: web::Json, +) -> impl Responder { + let recipient = req.into_inner().wallet; + state.increase_mint_requests(); let result = - web::block(move || ds.mint(&req.into_inner().wallet).map_err(|e| e.to_string())) - .await - .unwrap(); // TODO: check if this can get a BlockingError + web::block(move || sol_client.mint(&recipient).map_err(|e| e.to_string())).await.unwrap(); // TODO: check if this can get a BlockingError + match result { - Ok(s) => HttpResponse::Ok().body(format!(r#"{{" signature": "{s} "}}"#)), + Ok(s) => { + state.increase_mints(); + HttpResponse::Ok().body(format!(r#"{{" signature": "{s} "}}"#)) + } Err(e) => HttpResponse::InternalServerError().body(format!(r#"{{ "error": "{e}" }}"#)), } } -pub async fn init(ds: Arc) { +pub async fn init(state: Arc, sol_client: Arc) { HttpServer::new(move || { App::new() - .app_data(web::Data::new(ds.clone())) + .app_data(web::Data::new(state.clone())) + .app_data(web::Data::new(sol_client.clone())) .service(homepage) .service(get_nodes) .service(mint) diff --git a/src/main.rs b/src/main.rs index e6224ff..a850281 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,9 +5,11 @@ mod persistence; mod sgx; mod solana; -use crate::{datastore::LOCALHOST, grpc::challenge::NodeUpdate, solana::Client as SolClient}; -use datastore::Store; -use solana_sdk::signer::Signer; +use crate::{ + datastore::LOCALHOST, grpc::challenge::NodeUpdate, persistence::KeysFile, solana::SolClient, +}; +use datastore::State; +use detee_sgx::{InstanceMeasurement, RaTlsConfig}; use std::{ fs::File, io::{BufRead, BufReader}, @@ -19,52 +21,61 @@ use tokio::{ time::{sleep, Duration}, }; -const INIT_NODES: &str = "/host/detee_challenge_nodes"; -const DISK_PERSISTENCE: &str = "/host/main/TRY_TO_HACK_THIS"; -const MAINTAINED_CONNECTIONS: usize = 3; +const INIT_NODES_FILE: &str = "/host/detee_challenge_nodes"; +const KEYS_FILE: &str = "/host/TRY_TO_HACK_THIS"; +const MAX_CONNECTIONS: usize = 3; -pub async fn localhost_cron(ds: Arc, tx: Sender) { +pub async fn localhost_cron(state: Arc, tx: Sender) { loop { sleep(Duration::from_secs(60)).await; - let _ = tx.send((LOCALHOST.to_string(), ds.get_localhost()).into()); - ds.remove_inactive_nodes(); + let _ = tx.send((LOCALHOST.to_string(), state.get_localhost()).into()); + state.remove_inactive_nodes(); } } -async fn get_sol_client() -> SolClient { - match crate::persistence::Data::read(DISK_PERSISTENCE).await { - Ok(data) => { - let (keypair, token) = data.parse(); - println!("Found the following wallet saved to disk: {}", keypair.pubkey()); - println!("Loading token mint address {}", token); - return SolClient::from(keypair, token); +async fn get_sol_client(state: Arc, ratls_config: RaTlsConfig) -> SolClient { + match KeysFile::read(KEYS_FILE, &state).await { + Ok(keys_file) => { + let sol_client = SolClient::try_from(keys_file).unwrap(); + println!( + "Found the following wallet saved to disk: {}", + sol_client.get_wallet_pubkey() + ); + println!("Loading token mint address {}", sol_client.get_token_address()); + return sol_client; } - Err(e) => println!("Did not find old pubkeys saved to disk: {e}"), + Err(e) => println!("Can't initialize using sealed keys: {e}"), }; - let input = match File::open(INIT_NODES) { - Ok(i) => i, + let init_nodes = match File::open(INIT_NODES_FILE) { + Ok(init_nodes) => init_nodes, Err(_) => { - println!("Could not find remote nodes in the file {INIT_NODES}"); + println!("Can't initialize using init nodes from {INIT_NODES_FILE}"); println!("Starting a new network with a new key..."); return SolClient::new().await; } }; - let buffered = BufReader::new(input); - for line in buffered.lines() { - match grpc::client::key_grabber(line.unwrap()).await { - Ok(bundle) => { + + let init_nodes_reader = BufReader::new(init_nodes); + for init_node_ip in init_nodes_reader.lines().map(|l| l.unwrap()) { + match grpc::client::key_grabber(state.clone(), init_node_ip, ratls_config.clone()).await { + Ok(keys) => { + let sol_client = SolClient::try_from(keys.clone()) + .map_err(|e| { + println!("Received malformed keys from the network: {e}"); + state.increase_net_attacks(); + }) + .unwrap(); println!( "Got keypair from the network. Joining the network using wallet {}", - bundle.0.pubkey() + sol_client.get_wallet_pubkey() ); - println!("The address of the Token is {}", bundle.1); - println!("Saving this data to disk in the file {DISK_PERSISTENCE}"); - let disk_data = crate::persistence::Data::init_from(&bundle.0, &bundle.1).await; - if let Err(e) = disk_data.write(DISK_PERSISTENCE).await { - println!("Could not save data to disk due to: {e}"); + println!("The address of the Token is {}", sol_client.get_token_address()); + println!("Saving this data to disk in the file {KEYS_FILE}"); + if let Err(e) = sol_client.get_keys_file().write(KEYS_FILE).await { + println!("Could not save data to disk: {e}"); } - return SolClient::from(bundle.0, bundle.1); + return sol_client; } Err(e) => { println!("Could not get keypair: {e:?}"); @@ -77,29 +88,42 @@ async fn get_sol_client() -> SolClient { #[tokio::main] async fn main() { env_logger::init_from_env(env_logger::Env::default().default_filter_or("trace")); + let ratls_config = RaTlsConfig::new() + .allow_instance_measurement(InstanceMeasurement::new().with_current_mrenclave().unwrap()); - let sol_client = get_sol_client().await; - let ds = Arc::new(Store::init(sol_client)); + let state = Arc::new(State::new()); + let sol_client = Arc::new(get_sol_client(state.clone(), ratls_config.clone()).await); let (tx, _) = broadcast::channel(500); let mut tasks = JoinSet::new(); - tasks.spawn(localhost_cron(ds.clone(), tx.clone())); - tasks.spawn(http_server::init(ds.clone())); - tasks.spawn(grpc::server::MyServer::init(ds.clone(), tx.clone()).start()); + tasks.spawn(localhost_cron(state.clone(), tx.clone())); + tasks.spawn(http_server::init(state.clone(), sol_client.clone())); + tasks.spawn( + grpc::server::MyServer::init( + state.clone(), + sol_client.get_keys(), + ratls_config.clone(), + tx.clone(), + ) + .start(), + ); - if let Ok(input) = std::fs::read_to_string(INIT_NODES) { + if let Ok(input) = std::fs::read_to_string(INIT_NODES_FILE) { for line in input.lines() { tasks.spawn( - grpc::client::ConnManager::init(ds.clone(), tx.clone()) + grpc::client::ConnManager::init(state.clone(), ratls_config.clone(), tx.clone()) .start_with_node(line.to_string()), ); } } - for _ in 0..MAINTAINED_CONNECTIONS { - tasks.spawn(grpc::client::ConnManager::init(ds.clone(), tx.clone()).start()); + for _ in 0..MAX_CONNECTIONS { + tasks.spawn( + grpc::client::ConnManager::init(state.clone(), ratls_config.clone(), tx.clone()) + .start(), + ); } while let Some(Ok(_)) = tasks.join_next().await {} diff --git a/src/persistence.rs b/src/persistence.rs index a5a80e0..6de9a81 100644 --- a/src/persistence.rs +++ b/src/persistence.rs @@ -1,44 +1,53 @@ +use crate::{datastore::State, grpc::challenge::Keys}; use serde::{Deserialize, Serialize}; -use solana_sdk::{pubkey::Pubkey, signature::keypair::Keypair}; -use std::str::FromStr; +use serde_with::base64::Base64; +#[serde_with::serde_as] #[derive(Serialize, Deserialize)] -pub struct Data { +pub struct KeysFile { random: String, - keypair: String, + #[serde_as(as = "Base64")] + keypair: Vec, token: String, } -impl Data { - pub async fn init_from(keypair: &Keypair, token: &Pubkey) -> Self { +impl From for KeysFile { + fn from(keys: Keys) -> Self { use rand::{distributions::Alphanumeric, Rng}; - let random_string: String = + let random: String = rand::thread_rng().sample_iter(&Alphanumeric).take(128).map(char::from).collect(); - Self { - random: random_string, - keypair: keypair.to_base58_string(), - token: token.to_string(), - } + Self { keypair: keys.keypair, token: keys.token_address, random } } +} +impl Into for KeysFile { + fn into(self) -> Keys { + Keys { keypair: self.keypair, token_address: self.token } + } +} + +impl KeysFile { pub async fn write(self, path: &str) -> Result<(), Box> { let serialized = serde_json::to_string(&self)?; let sealed = detee_sgx::SealingConfig::new()?.seal_data(serialized.into_bytes())?; tokio::fs::write(path, sealed).await.map_err(Into::into) } - pub async fn read(path: &str) -> Result> { + pub async fn read(path: &str, state: &State) -> Result> { let sealed = tokio::fs::read(path).await?; - let serialized = detee_sgx::SealingConfig::new()?.un_seal_data(sealed)?; + let serialized = detee_sgx::SealingConfig::new()?.un_seal_data(sealed).map_err(|e| { + match e { + detee_sgx::SgxError::UnSealingError(ref ue) => { + state.increase_disk_attacks(); + println!("The disk data is corrupted: {ue}"); + } + _ => println!("Failed to unseal data: {e}"), + }; + e + })?; Ok(serde_json::from_str(&String::from_utf8(serialized)?)?) } - - pub fn parse(self) -> (Keypair, Pubkey) { - let keypair = Keypair::from_base58_string(&self.keypair); - let pubkey = Pubkey::from_str(&self.token).unwrap(); - (keypair, pubkey) - } } const LOG_PATH: &str = "/host/logs"; @@ -48,9 +57,7 @@ pub struct Logfile {} impl Logfile { pub fn append(msg: &str) -> Result<(), Box> { use std::io::Write; - let mut file = std::fs::OpenOptions::new() - .append(true) - .open(LOG_PATH)?; + let mut file = std::fs::OpenOptions::new().append(true).open(LOG_PATH)?; file.write_all(msg.as_bytes())?; Ok(()) } diff --git a/src/solana.rs b/src/solana.rs index 105b356..f62d775 100644 --- a/src/solana.rs +++ b/src/solana.rs @@ -1,4 +1,5 @@ #![allow(dead_code)] +use crate::{grpc::challenge::Keys, persistence::KeysFile}; use solana_client::rpc_client::RpcClient; use solana_program::program_pack::Pack; use solana_sdk::{ @@ -12,18 +13,17 @@ use spl_token::{ instruction::{initialize_mint, mint_to}, state::Mint, }; -use std::error::Error; use tokio::time::{sleep, Duration}; const RPC_URL: &str = "https://api.devnet.solana.com"; -pub struct Client { +pub struct SolClient { client: RpcClient, keypair: Keypair, token: Pubkey, } -impl Client { +impl SolClient { pub async fn new() -> Self { let client = RpcClient::new(RPC_URL); let keypair = Keypair::new(); @@ -31,12 +31,18 @@ impl Client { Self { client, keypair, token } } - pub fn from(keypair: Keypair, token: Pubkey) -> Self { - Self { client: RpcClient::new(RPC_URL), keypair, token } + pub fn get_keys(&self) -> Keys { + Keys { keypair: self.get_keypair_bytes(), token_address: self.get_token_address() } } - pub fn mint(&self, recipient: &Pubkey) -> Result> { - let associated_token_address = self.create_token_account(recipient)?; + pub fn get_keys_file(&self) -> KeysFile { + self.get_keys().into() + } + + pub fn mint(&self, recipient: &str) -> Result> { + use std::str::FromStr; + let recipient = solana_sdk::pubkey::Pubkey::from_str(recipient)?; + let associated_token_address = self.create_token_account(&recipient)?; let mint_to_instruction = mint_to( &spl_token::id(), &self.token, @@ -56,7 +62,10 @@ impl Client { Ok(signature.to_string()) } - fn create_token_account(&self, recipient: &Pubkey) -> Result> { + fn create_token_account( + &self, + recipient: &Pubkey, + ) -> Result> { let address = get_associated_token_address(recipient, &self.token); if self.client.get_account(&address).is_err() { let create_token_account_instruction = create_associated_token_account( @@ -77,11 +86,12 @@ impl Client { Ok(address) } - pub fn wallet_address(&self) -> String { + pub fn get_wallet_pubkey(&self) -> String { + // Return the base58 string representation of the public key self.keypair.pubkey().to_string() } - pub fn token_address(&self) -> String { + pub fn get_token_address(&self) -> String { self.token.to_string() } @@ -90,6 +100,30 @@ impl Client { } } +impl TryFrom for SolClient { + type Error = String; + + fn try_from(keys: Keys) -> Result { + use std::str::FromStr; + let keypair = match Keypair::from_bytes(&keys.keypair) { + Ok(k) => k, + Err(_) => return Err("Could not parse keypair.".into()), + }; + let token = Pubkey::from_str(&keys.token_address) + .map_err(|_| "Could not parse wallet address.".to_string())?; + Ok(Self { client: RpcClient::new(RPC_URL), keypair, token }) + } +} + +impl TryFrom for SolClient { + type Error = String; + + fn try_from(keys_file: KeysFile) -> Result { + let keys: Keys = keys_file.into(); + Self::try_from(keys) + } +} + async fn wait_for_sol(client: &RpcClient, pubkey: &Pubkey) { println!("Waiting to receive 0.01 SOL in address {pubkey}"); loop {