diff --git a/src/grpc/client.rs b/src/grpc/client.rs index 3ae2f4a..2a79cf0 100644 --- a/src/grpc/client.rs +++ b/src/grpc/client.rs @@ -1,4 +1,5 @@ -use super::challenge::{Keys, NodeUpdate}; +use super::challenge::Keys; +use super::InternalNodeUpdate; use crate::{ datastore::State, grpc::challenge::{update_client::UpdateClient, Empty}, @@ -15,7 +16,7 @@ use tokio_stream::{wrappers::BroadcastStream, StreamExt}; pub struct ConnManager { my_ip: String, state: Arc, - tx: Sender, + tx: Sender, ratls_config: RaTlsConfig, } @@ -24,7 +25,7 @@ impl ConnManager { my_ip: String, state: Arc, ratls_config: RaTlsConfig, - tx: Sender, + tx: Sender, ) -> Self { Self { my_ip, state, ratls_config, tx } } @@ -52,13 +53,13 @@ impl ConnManager { state.delete_conn(&node_ip); } - async fn connect(&self, node_ip: String) -> Result<(), Box> { + async fn connect(&self, remote_ip: String) -> Result<(), Box> { use detee_sgx::RaTlsConfigBuilder; use hyper::Uri; use hyper_util::{client::legacy::connect::HttpConnector, rt::TokioExecutor}; use tokio_rustls::rustls::ClientConfig; - println!("Connecting to {node_ip}..."); + println!("Connecting to {remote_ip}..."); let tls = ClientConfig::from_ratls_config(self.ratls_config.clone()) .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, format!("{}", e)))?; @@ -66,7 +67,7 @@ impl ConnManager { let mut http = HttpConnector::new(); http.enforce_http(false); - let cloned_node_ip = node_ip.clone(); + let cloned_node_ip = remote_ip.clone(); let connector = tower::ServiceBuilder::new() .layer_fn(move |s| { @@ -91,9 +92,18 @@ impl ConnManager { let mut client = UpdateClient::with_origin(client, uri); let rx = self.tx.subscribe(); - let rx_stream = BroadcastStream::new(rx).filter_map(|n| n.ok()); + let cloned_remote_ip = remote_ip.clone(); + let rx_stream = + BroadcastStream::new(rx).filter_map(|n| n.ok()).filter_map(move |int_update| { + if int_update.sender_ip != cloned_remote_ip { + Some(int_update.update) + } else { + None + } + }); + let response = client.get_updates(rx_stream).await.map_err(|e| { - println!("Error connecting to {node_ip}: {e}"); + println!("Error connecting to {remote_ip}: {e}"); if e.to_string().contains("QuoteVerifyError") { self.state.increase_net_attacks(); } @@ -101,15 +111,25 @@ impl ConnManager { })?; let mut resp_stream = response.into_inner(); - let _ = self.tx.send((self.my_ip.clone(), self.state.get_my_info()).into()); + // Immediately send our info as a network update + let my_info = (self.my_ip.clone(), self.state.get_my_info()).into(); + let _ = self.tx.send(InternalNodeUpdate { sender_ip: self.my_ip.clone(), update: my_info }); while let Some(update) = resp_stream.message().await? { // update the entire network in case the information is new - if self.state.process_node_update(update.clone().into()) - && self.tx.send(update.clone()).is_err() - { - println!("Tokio broadcast receivers had an issue consuming the channel"); - }; + if self.state.process_node_update(update.clone().into()) { + // if process update returns true, the update must be forwarded + if self + .tx + .send(InternalNodeUpdate { + sender_ip: remote_ip.clone(), + update: update.clone(), + }) + .is_err() + { + println!("Tokio broadcast receivers had an issue consuming the channel"); + }; + } } Ok(()) diff --git a/src/grpc/mod.rs b/src/grpc/mod.rs index f29ca19..6f7eae1 100644 --- a/src/grpc/mod.rs +++ b/src/grpc/mod.rs @@ -7,6 +7,12 @@ pub mod challenge { tonic::include_proto!("challenge"); } +#[derive(Clone, PartialEq)] +pub struct InternalNodeUpdate { + pub sender_ip: String, + pub update: NodeUpdate, +} + impl From<(String, NodeInfo)> for NodeUpdate { fn from((ip, info): (String, NodeInfo)) -> Self { NodeUpdate { diff --git a/src/grpc/server.rs b/src/grpc/server.rs index 25a7ec7..7c51697 100644 --- a/src/grpc/server.rs +++ b/src/grpc/server.rs @@ -1,4 +1,5 @@ use super::challenge::{update_server::UpdateServer, Empty, Keys, NodeUpdate}; +use super::InternalNodeUpdate; use crate::{datastore::State, grpc::challenge::update_server::Update}; use detee_sgx::RaTlsConfig; use rustls::pki_types::CertificateDer; @@ -9,7 +10,7 @@ use tonic::{Request, Response, Status, Streaming}; pub struct MyServer { state: Arc, - tx: Sender, + tx: Sender, ratls_config: RaTlsConfig, keys: Keys, // For sending secret keys to new nodes ;) } @@ -19,7 +20,7 @@ impl MyServer { state: Arc, keys: Keys, ratls_config: RaTlsConfig, - tx: Sender, + tx: Sender, ) -> Self { Self { state, tx, keys, ratls_config } } @@ -160,9 +161,13 @@ impl Update for MyServer { } else { println!("Node {remote_ip} is forwarding the update of {}", update.ip); } - if state.process_node_update(update.clone().into()) && tx.send(update.clone()).is_err() { - println!("Tokio broadcast receivers had an issue consuming the channel"); - }; + + if state.process_node_update(update.clone().into()) { + // if process update returns true, the update must be forwarded + if tx.send(InternalNodeUpdate { sender_ip: remote_ip.clone(), update: update.clone() }).is_err() { + println!("Tokio broadcast receivers had an issue consuming the channel"); + }; + } } Err(e) => { error_status = Status::internal(format!("Error receiving client stream: {}", e)); @@ -171,7 +176,10 @@ impl Update for MyServer { } } Ok(update) = rx.recv() => { - yield Ok(update); + if update.sender_ip != remote_ip { + // don't bounce back the update we just received + yield Ok(update.update); + } // disconnect client if too many connections are active if tx.receiver_count() > 9 { error_status = Status::internal("Already have too many clients. Connect to another server."); diff --git a/src/main.rs b/src/main.rs index 4278009..57bf275 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,7 +6,8 @@ mod solana; use crate::persistence::SealError; use crate::{ - grpc::challenge::NodeUpdate, persistence::KeysFile, persistence::SealedFile, solana::SolClient, + grpc::challenge::NodeUpdate, grpc::InternalNodeUpdate, persistence::KeysFile, + persistence::SealedFile, solana::SolClient, }; use datastore::State; use detee_sgx::{InstanceMeasurement, RaTlsConfig}; @@ -45,11 +46,12 @@ async fn resolve_my_ip() -> Result { Ok(format!("{}", ip)) } -pub async fn heartbeat_cron(my_ip: String, state: Arc, tx: Sender) { +pub async fn heartbeat_cron(my_ip: String, state: Arc, tx: Sender) { loop { sleep(Duration::from_secs(60)).await; println!("Heartbeat..."); - let _ = tx.send((my_ip.clone(), state.get_my_info()).into()); + let update = (my_ip.clone(), state.get_my_info()).into(); + let _ = tx.send(InternalNodeUpdate { sender_ip: my_ip.clone(), update }); state.remove_inactive_nodes(); } }