diff --git a/Cargo.lock b/Cargo.lock index 194b57a..0f4be4c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2236,6 +2236,7 @@ dependencies = [ "futures-core", "pin-project-lite", "tokio", + "tokio-util", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 53bdb21..f202878 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,7 @@ rand = "0.8.5" salvo = { version = "0.70.0", features = ["affix"] } tabled = "0.16.0" tokio = { version = "1.39.2", features = ["macros"] } -tokio-stream = { version = "0.1.15" } +tokio-stream = { version = "0.1.15", features = ["sync"] } tonic = "0.12.1" [build-dependencies] diff --git a/src/datastore.rs b/src/datastore.rs index 5d4e7c6..6b5acbb 100644 --- a/src/datastore.rs +++ b/src/datastore.rs @@ -177,7 +177,7 @@ impl Store { } } - // returns old pubkey if node got updated + /// returns old pubkey if node got updated async fn update_node(&self, ip: String, info: NodeInfo) -> Option { let mut nodes = self.nodes.lock().await; match nodes.insert(ip, info.clone()) { @@ -230,4 +230,27 @@ impl Store { }) .collect() } + + /// you can specify the online argument to get only nodes that are online + pub async fn get_random_nodes(&self, online: bool) -> Vec { + use rand::rngs::OsRng; + use rand::RngCore; + let nodes = self.nodes.lock().await; + let len = nodes.len(); + let skip = OsRng.next_u64().try_into().unwrap_or(0) % len; + let mut iter = nodes.iter().cycle().skip(skip); + let mut random_nodes = vec![]; + let mut count = 0; + let mut iterations = 0; + while count < 3 && iterations < len { + if let Some((ip, info)) = iter.next() { + if online || info.online { + random_nodes.push(ip.clone()); + count -= 1; + } + iterations += 1; + } + } + random_nodes + } } diff --git a/src/grpc/client.rs b/src/grpc/client.rs index e69de29..15fc38c 100644 --- a/src/grpc/client.rs +++ b/src/grpc/client.rs @@ -0,0 +1,65 @@ +#![allow(dead_code)] +use super::challenge::NodeUpdate; +use crate::datastore::Store; +use crate::grpc::challenge::update_client::UpdateClient; +use std::sync::Arc; +use std::thread::JoinHandle; +use tokio::sync::broadcast::Sender; +use tokio::task::JoinSet; +use tokio_stream::wrappers::BroadcastStream; +use tokio_stream::StreamExt; + +struct Connection { + ds: Arc, + tx: Sender, +} + +#[derive(Clone)] +struct ConnManager { + ds: Arc, + tx: Sender, +} + +impl ConnManager { + fn init(ds: Arc, tx: Sender) -> Self { + Self { ds, tx } + } + + async fn connect(self, node_ip: String) { + let mut client = UpdateClient::connect(format!("http://{node_ip}:50051")) + .await + .unwrap(); + + let rx = self.tx.subscribe(); + let rx_stream = BroadcastStream::new(rx).filter_map(|n| n.ok()); + let response = client.get_updates(rx_stream).await.unwrap(); + let mut resp_stream = response.into_inner(); + + while let Some(update) = resp_stream.message().await.unwrap() { + if self.ds.process_grpc_update(update.clone()).await { + if let Err(_) = self.tx.send(update.clone()) { + println!("tokio broadcast receivers had an issue consuming the channel"); + } + }; + } + + } +} + +fn init_connections(ds: Arc, tx: Sender) -> JoinHandle<()> { + std::thread::spawn(move || { + tokio::runtime::Runtime::new().unwrap().block_on(async { + let mut only_online_nodes = true; + loop { + let mut set = JoinSet::new(); + let nodes = ds.get_random_nodes(only_online_nodes).await; + for node in nodes { + let conn = ConnManager::init(ds.clone(), tx.clone()); + set.spawn(conn.connect(node)); + } + while let Some(_) = set.join_next().await {} + only_online_nodes = !only_online_nodes; + } + }); + }) +}