hacker-challenge/src/grpc/client.rs
Valentyn Faychuk 611421a9c0
connections refactoring
Signed-off-by: Valentyn Faychuk <valy@detee.ltd>
2024-12-24 00:53:30 +02:00

182 lines
6.1 KiB
Rust

use super::{challenge::Keys, InternalNodeUpdate};
use crate::{
datastore::State,
grpc::challenge::{update_client::UpdateClient, Empty},
};
use detee_sgx::RaTlsConfig;
use std::{net::Ipv4Addr, str::FromStr, sync::Arc};
use tokio::{sync::broadcast::Sender, time::Duration};
use tokio_stream::{wrappers::BroadcastStream, StreamExt};
/// Opens a long-lived update exchange with `node_ip` if it is a valid IPv4 address.
///
/// An invalid address is logged and the function returns immediately.
/// Otherwise a [`ConnManager`] is built from `state`, `ra_cfg` and `tx`, and this
/// call runs until that connection ends. Connection errors are logged by the
/// manager, not returned.
pub async fn grpc_new_conn(
    node_ip: String,
    state: Arc<State>,
    ra_cfg: RaTlsConfig,
    tx: Sender<InternalNodeUpdate>,
) {
    let is_valid_ipv4 = node_ip.parse::<Ipv4Addr>().is_ok();
    if !is_valid_ipv4 {
        println!("IPv4 address is invalid: {node_ip}");
        return;
    }
    let manager = ConnManager::init(state, ra_cfg, tx);
    manager.connect_to(node_ip).await;
}
/// Queries `node_ip` for its [`Keys`] after checking that it is a valid IPv4 address.
///
/// # Errors
/// Returns a boxed `std::io::Error` of kind `Other` when `node_ip` does not parse
/// as IPv4. Otherwise it forwards whatever error the key query produces.
pub async fn grpc_query_keys(
    node_ip: String,
    state: Arc<State>,
    ra_cfg: RaTlsConfig,
) -> Result<Keys, Box<dyn std::error::Error>> {
    match node_ip.parse::<Ipv4Addr>() {
        Ok(_) => query_keys(node_ip, state, ra_cfg).await,
        Err(_) => {
            let msg = format!("IPv4 address is invalid: {node_ip}");
            let err: Box<dyn std::error::Error> =
                Box::new(std::io::Error::new(std::io::ErrorKind::Other, msg));
            Err(err)
        }
    }
}
/// Context for one outgoing node-to-node gRPC connection.
#[derive(Clone)]
struct ConnManager {
    /// Shared node state: connection registry, network view and attack counter.
    state: Arc<State>,
    /// RA-TLS settings used to build the client-side TLS configuration.
    ra_cfg: RaTlsConfig,
    /// Broadcast channel carrying node updates between all connections.
    tx: Sender<InternalNodeUpdate>,
}
impl ConnManager {
    /// Bundles the shared state, RA-TLS config and broadcast sender into a manager.
    fn init(state: Arc<State>, ra_cfg: RaTlsConfig, tx: Sender<InternalNodeUpdate>) -> Self {
        Self { state, ra_cfg, tx }
    }

    /// Registers `node_ip` as an active connection, runs the update exchange with
    /// it, then unregisters it.
    ///
    /// Errors from the exchange are logged, not returned.
    async fn connect_to(&self, node_ip: String) {
        self.state.add_conn(&node_ip);
        if let Err(e) = self.connect_to_int(node_ip.clone()).await {
            println!("Client connection for {node_ip} failed: {e:?}");
        }
        // Reached on every non-panicking path, so the registration never leaks
        // when the exchange fails with an error.
        self.state.delete_conn(&node_ip);
    }

    /// Opens an RA-TLS gRPC connection to `remote_ip` on port 31373 and exchanges
    /// node updates until the remote side closes the response stream.
    ///
    /// - Outgoing: every broadcast update whose `sender_ip` is not `remote_ip`.
    ///   Broadcast receive errors (e.g. lag) are silently dropped.
    /// - Incoming: each update is applied via `State::process_node_update`. When
    ///   that returns `true`, the update is re-broadcast tagged with `remote_ip`.
    ///
    /// # Errors
    /// Fails if the TLS config cannot be built, the target URI is invalid, the
    /// `get_updates` call fails, or the response stream yields an error. A failed
    /// call whose message contains `QuoteVerifyError` also increments the
    /// net-attack counter.
    async fn connect_to_int(&self, remote_ip: String) -> Result<(), Box<dyn std::error::Error>> {
        use detee_sgx::RaTlsConfigBuilder;
        use hyper::Uri;
        use hyper_util::{client::legacy::connect::HttpConnector, rt::TokioExecutor};
        use tokio_rustls::rustls::ClientConfig;
        println!("Connecting to {remote_ip}...");
        let tls = ClientConfig::from_ratls_config(self.ra_cfg.clone())
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?;
        // Parse the dial target once, up front. A malformed address now becomes an
        // error for the caller instead of a panic inside the connector service.
        // Such a panic would also skip `delete_conn` in `connect_to`.
        let target = Uri::from_str(&format!("https://{remote_ip}:31373"))?;
        let mut http = HttpConnector::new();
        http.enforce_http(false);
        // NOTE(review): tower's ServiceBuilder makes the first-added layer the
        // outermost. The 5s timeout and the URI rewrite therefore wrap only the
        // inner TCP connector, not the TLS handshake. Confirm this is intended.
        let connector = tower::ServiceBuilder::new()
            .layer_fn(move |s| {
                let tls = tls.clone();
                hyper_rustls::HttpsConnectorBuilder::new()
                    .with_tls_config(tls)
                    .https_or_http()
                    .enable_http2()
                    .wrap_connector(s)
            })
            .timeout(Duration::from_secs(5))
            // Whatever URI the client asks for, dial the remote node instead.
            .map_request(move |_| target.clone())
            .service(http);
        let client =
            hyper_util::client::legacy::Client::builder(TokioExecutor::new()).build(connector);
        // Placeholder origin: the connector above rewrites the dial target.
        let uri = Uri::from_static("https://example.com");
        let mut client = UpdateClient::with_origin(client, uri);
        let rx = self.tx.subscribe();
        let cloned_remote_ip = remote_ip.clone();
        // Do not echo updates that originally came from this same peer.
        let rx_stream =
            BroadcastStream::new(rx).filter_map(|n| n.ok()).filter_map(move |int_update| {
                if int_update.sender_ip != cloned_remote_ip {
                    Some(int_update.update)
                } else {
                    None
                }
            });
        let response = client.get_updates(rx_stream).await.map_err(|e| {
            println!("Error connecting to {remote_ip}: {e}");
            if e.to_string().contains("QuoteVerifyError") {
                self.state.increase_net_attacks();
            }
            e
        })?;
        let mut resp_stream = response.into_inner();
        // Immediately send our info as a network update (best effort).
        let _ = self.tx.send((self.state.get_my_ip(), self.state.get_my_info()).into());
        while let Some(update) = resp_stream.message().await? {
            // Forward only updates that changed our view of the network.
            let is_new = self.state.process_node_update(update.clone().into());
            if is_new && self.tx.send((remote_ip.clone(), update).into()).is_err() {
                println!("Tokio broadcast receivers had an issue consuming the channel");
            }
        }
        Ok(())
    }
}
/// Connects to `node_ip` on port 31373 over RA-TLS and calls the `get_keys` RPC once.
///
/// # Errors
/// Fails if the TLS config cannot be built, the target URI is invalid, or the RPC
/// fails. A failure whose message contains `QuoteVerifyError` also increments the
/// net-attack counter on `state`.
async fn query_keys(
    node_ip: String,
    state: Arc<State>,
    ra_cfg: RaTlsConfig,
) -> Result<Keys, Box<dyn std::error::Error>> {
    use detee_sgx::RaTlsConfigBuilder;
    use hyper::Uri;
    use hyper_util::{client::legacy::connect::HttpConnector, rt::TokioExecutor};
    use tokio_rustls::rustls::ClientConfig;
    println!("Getting key from {node_ip}...");
    let tls = ClientConfig::from_ratls_config(ra_cfg)
        .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?;
    // Parse the dial target once. An invalid address becomes an error instead of
    // a panic inside the connector service.
    let target = Uri::from_str(&format!("https://{node_ip}:31373"))?;
    let mut http = HttpConnector::new();
    http.enforce_http(false);
    let connector = tower::ServiceBuilder::new()
        .layer_fn(move |s| {
            let tls = tls.clone();
            hyper_rustls::HttpsConnectorBuilder::new()
                .with_tls_config(tls)
                .https_or_http()
                .enable_http2()
                .wrap_connector(s)
        })
        .timeout(Duration::from_secs(5))
        // Whatever URI the client asks for, dial the remote node instead.
        .map_request(move |_| target.clone())
        .service(http);
    let client = hyper_util::client::legacy::Client::builder(TokioExecutor::new()).build(connector);
    // Placeholder origin: the connector above rewrites the dial target.
    let uri = Uri::from_static("https://example.com");
    let mut client = UpdateClient::with_origin(client, uri);
    let response = client.get_keys(tonic::Request::new(Empty {})).await.map_err(|e| {
        println!("Error getting keys from {node_ip}: {e}");
        if e.to_string().contains("QuoteVerifyError") {
            state.increase_net_attacks();
        }
        e
    })?;
    Ok(response.into_inner())
}