hacker-challenge-sgx-general/src/grpc/client.rs
jakubDoka 09a84a15f3 rewrite (#2)
running clippy fix
separating homepage to a file
adding summary of network security
removing the rewrite structure
removing catch unwind
adding sealing to persistence
redirecting to the upstream
fixing some startup edge cases

Co-authored-by: Jakub Doka <jakub.doka2@gmail.com>
Reviewed-on: SGX/hacker-challenge-sgx#2
2024-11-08 14:33:42 +00:00

174 lines
6.2 KiB
Rust

#![allow(dead_code)]
use super::challenge::NodeUpdate;
use crate::{
datastore::{Store, LOCALHOST},
grpc::challenge::{update_client::UpdateClient, Empty},
};
use solana_sdk::{pubkey::Pubkey, signature::keypair::Keypair};
use std::{str::FromStr, sync::Arc};
use tokio::{
sync::broadcast::Sender,
time::{sleep, Duration},
};
use tokio_stream::{wrappers::BroadcastStream, StreamExt};
/// Manages outbound gRPC connections from this node to other nodes in the network.
#[derive(Clone)]
pub struct ConnManager {
    /// Shared node datastore: source of peers to dial and sink for received node updates.
    ds: Arc<Store>,
    /// Broadcast channel of node updates; every connection subscribes to it and streams
    /// its messages to the peer through `get_updates`.
    tx: Sender<NodeUpdate>,
}

impl ConnManager {
    /// Creates a manager backed by the datastore `ds` and the update broadcast channel `tx`.
    pub fn init(ds: Arc<Store>, tx: Sender<NodeUpdate>) -> Self {
        Self { ds, tx }
    }

    /// Connects to the single node `node_ip` and returns once that connection ends.
    pub async fn start_with_node(self, node_ip: String) {
        self.connect_wrapper(node_ip).await;
    }

    /// Runs forever: picks a random known node (skipping [`LOCALHOST`]), stays connected
    /// to it until the connection ends, then waits 3 seconds before the next attempt.
    pub async fn start(self) {
        loop {
            if let Some(node) = self.ds.get_random_node() {
                if node != LOCALHOST {
                    self.connect_wrapper(node).await;
                }
            }
            sleep(Duration::from_secs(3)).await;
        }
    }

    /// Registers `node_ip` as an active connection for the duration of [`Self::connect`].
    /// Any connection error is logged rather than propagated.
    async fn connect_wrapper(&self, node_ip: String) {
        self.ds.add_conn(&node_ip);
        if let Err(e) = self.connect(node_ip.clone()).await {
            println!("Client connection for {node_ip} failed: {e:?}");
        }
        self.ds.delete_conn(&node_ip);
    }

    /// Opens an RA-TLS-authenticated gRPC update stream to `node_ip` and exchanges node
    /// updates until the peer closes the stream.
    ///
    /// Every message published on `self.tx` is streamed to the peer, and this node's own
    /// info is announced right after the stream opens. Each update received from the peer
    /// is applied to the datastore and, if it changed anything, re-broadcast on `self.tx`.
    ///
    /// # Errors
    /// Returns an error in any of these cases:
    /// - `node_ip` does not form a valid URI;
    /// - the RA-TLS client config cannot be built;
    /// - the `get_updates` call fails;
    /// - the response stream yields an error.
    ///
    /// # Panics
    /// Panics if the embedded MRSIGNER hex constant fails to decode.
    async fn connect(&self, node_ip: String) -> Result<(), Box<dyn std::error::Error>> {
        use detee_sgx::{prelude::*, RaTlsConfigBuilder};
        use hyper::Uri;
        use hyper_util::{client::legacy::connect::HttpConnector, rt::TokioExecutor};
        use tokio_rustls::rustls::ClientConfig;

        /// MRSIGNER that the peer enclave must present during RA-TLS.
        const MRSIGNER_HEX: &str =
            "83E8A0C3ED045D9747ADE06C3BFC70FCA661A4A65FF79A800223621162A88B76";

        println!("Connecting to {node_ip}...");

        // `node_ip` can originate from peer-supplied node updates, so validate it once here
        // and report a bad value as an error instead of panicking inside the connector.
        let endpoint = format!("https://{node_ip}:31373")
            .parse::<Uri>()
            .map_err(|e| format!("Could not parse URI for {node_ip}: {e}"))?;

        let mrsigner =
            crate::sgx::mrsigner_from_hex(MRSIGNER_HEX).expect("mrsigner decoding failed");
        let config = RaTlsConfig::new()
            .allow_instance_measurement(InstanceMeasurement::new().with_mrsigners(vec![mrsigner]));
        let tls = ClientConfig::from_ratls_config(config)
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?;

        let mut http = HttpConnector::new();
        // Let the plain connector accept `https://` URIs; TLS is layered on top of it.
        http.enforce_http(false);
        let connector = tower::ServiceBuilder::new()
            .layer_fn(move |s| {
                let tls = tls.clone();
                hyper_rustls::HttpsConnectorBuilder::new()
                    .with_tls_config(tls)
                    .https_or_http()
                    .enable_http2()
                    .wrap_connector(s)
            })
            // Ignore the URI handed to the connector and always dial the validated endpoint.
            .map_request(move |_| endpoint.clone())
            .service(http);
        let client =
            hyper_util::client::legacy::Client::builder(TokioExecutor::new()).build(connector);
        // Placeholder origin: the connector above always dials `endpoint` regardless.
        let uri = Uri::from_static("https://example.com");
        let mut client = UpdateClient::with_origin(client, uri);

        // Subscribe before announcing ourselves so the announcement below is part of the
        // outgoing stream (broadcast receivers only see messages sent after subscribing).
        let rx = self.tx.subscribe();
        let rx_stream = BroadcastStream::new(rx).filter_map(|n| n.ok());
        let response = client.get_updates(rx_stream).await?;
        let mut resp_stream = response.into_inner();
        // Best effort: a broadcast send only fails when there are no receivers.
        let _ = self.tx.send((LOCALHOST.to_string(), self.ds.get_localhost()).into());

        while let Some(mut update) = resp_stream.message().await? {
            // "localhost" IPs need to be changed to the real IP of the counterpart.
            if update.ip == LOCALHOST {
                update.ip = node_ip.clone();
                // We reached this server by dialing it directly, so it is not behind NAT
                // and can be marked public.
                update.public = true;
            }
            // Update the entire network only when the information is new.
            if self.ds.process_node_update(update.clone().into()).await
                && self.tx.send(update).is_err()
            {
                println!("tokio broadcast receivers had an issue consuming the channel");
            }
        }
        Ok(())
    }
}
/// Fetches a keypair and a token address from the node at `node_ip` via the `get_keys` RPC.
///
/// The connection uses RA-TLS and only accepts a peer enclave with the expected MRSIGNER.
///
/// # Errors
/// Returns an error in any of these cases:
/// - `node_ip` does not form a valid URI;
/// - the RA-TLS client config cannot be built;
/// - the `get_keys` call fails;
/// - the returned keypair bytes or token address cannot be parsed.
///
/// # Panics
/// Panics if the embedded MRSIGNER hex constant fails to decode.
pub async fn key_grabber(node_ip: String) -> Result<(Keypair, Pubkey), Box<dyn std::error::Error>> {
    use detee_sgx::{prelude::*, RaTlsConfigBuilder};
    use hyper::Uri;
    use hyper_util::{client::legacy::connect::HttpConnector, rt::TokioExecutor};
    use tokio_rustls::rustls::ClientConfig;

    /// MRSIGNER that the peer enclave must present during RA-TLS.
    const MRSIGNER_HEX: &str = "83E8A0C3ED045D9747ADE06C3BFC70FCA661A4A65FF79A800223621162A88B76";

    println!("Getting key from {node_ip}...");

    // Validate the endpoint once, up front, so a malformed `node_ip` is reported as an
    // error instead of panicking inside the connector.
    let endpoint = format!("https://{node_ip}:31373")
        .parse::<Uri>()
        .map_err(|e| format!("Could not parse URI for {node_ip}: {e}"))?;

    let mrsigner = crate::sgx::mrsigner_from_hex(MRSIGNER_HEX).expect("mrsigner decoding failed");
    let config = RaTlsConfig::new()
        .allow_instance_measurement(InstanceMeasurement::new().with_mrsigners(vec![mrsigner]));
    let tls = ClientConfig::from_ratls_config(config)
        .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?;

    let mut http = HttpConnector::new();
    // Let the plain connector accept `https://` URIs; TLS is layered on top of it.
    http.enforce_http(false);
    let connector = tower::ServiceBuilder::new()
        .layer_fn(move |s| {
            let tls = tls.clone();
            hyper_rustls::HttpsConnectorBuilder::new()
                .with_tls_config(tls)
                .https_or_http()
                .enable_http2()
                .wrap_connector(s)
        })
        // Ignore the URI handed to the connector and always dial the validated endpoint.
        .map_request(move |_| endpoint.clone())
        .service(http);
    let client = hyper_util::client::legacy::Client::builder(TokioExecutor::new()).build(connector);
    // Placeholder origin: the connector above always dials `endpoint` regardless.
    let uri = Uri::from_static("https://example.com");
    let mut client = UpdateClient::with_origin(client, uri);

    let response = client.get_keys(tonic::Request::new(Empty {})).await?.into_inner();
    let keypair =
        Keypair::from_bytes(&response.keypair).map_err(|_| "Could not parse keypair.")?;
    let token_address = Pubkey::from_str(&response.token_address)
        .map_err(|_| "Could not parse wallet address.")?;
    Ok((keypair, token_address))
}