diff --git a/proto/challenge.proto b/proto/challenge.proto index 326eac3..4440f7e 100644 --- a/proto/challenge.proto +++ b/proto/challenge.proto @@ -7,7 +7,7 @@ message NodeUpdate { string ip = 1; string keypair = 2; google.protobuf.Timestamp updated_at = 3; - bool online = 4; + bool public = 4; } service Update { diff --git a/src/datastore.rs b/src/datastore.rs index 6dc37b9..74c6117 100644 --- a/src/datastore.rs +++ b/src/datastore.rs @@ -1,8 +1,7 @@ -#![allow(dead_code)] use crate::grpc::challenge::NodeUpdate; use ed25519_dalek::{Signer, SigningKey, VerifyingKey}; use rand::rngs::OsRng; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::time::Duration; use std::time::SystemTime; use std::time::UNIX_EPOCH; @@ -13,12 +12,13 @@ use tokio::sync::Mutex; pub struct NodeInfo { pub pubkey: VerifyingKey, pub updated_at: SystemTime, - pub online: bool, + pub public: bool, } /// Needs to be surrounded in an Arc. pub struct Store { nodes: Mutex>, + conns: Mutex>, keys: Mutex>, } pub enum SigningError { @@ -57,20 +57,32 @@ impl std::fmt::Display for SigningError { } impl Store { + // app should exit if any error happens here so unwrap() is good pub fn init() -> Self { Self { nodes: Mutex::new(HashMap::new()), keys: Mutex::new(HashMap::new()), + conns: Mutex::new(HashSet::new()), } } + pub async fn add_conn(&self, ip: &str) { + let mut conns = self.conns.lock().await; + conns.insert(ip.to_string()); + } + + pub async fn delete_conn(&self, ip: &str) { + let mut conns = self.conns.lock().await; + conns.remove(ip); + } + pub async fn tabled_node_list(&self) -> String { #[derive(Tabled)] struct OutputRow { ip: String, pubkey: String, age: u64, - online: bool, + public: bool, } let mut output = vec![]; for (ip, node_info) in self.nodes.lock().await.iter() { @@ -80,12 +92,12 @@ impl Store { .duration_since(node_info.updated_at) .unwrap_or(Duration::ZERO) .as_secs(); - let online = node_info.online; + let public = node_info.public; output.push(OutputRow { ip, pubkey, age, - online, + public, }); } Table::new(output).to_string() @@ -99,8 +111,7 @@ impl Store { let key_bytes = hex::decode(pubkey)?; let pubkey = VerifyingKey::from_bytes(&key_bytes.as_slice().try_into()?)?; - let key_store = self.keys.lock().await; - let signing_key = match { key_store.get(&pubkey) } { + let signing_key = match self.get_privkey(&pubkey).await { Some(k) => k, None => return Err(SigningError::KeyNotFound), }; @@ -154,11 +165,11 @@ impl Store { let node_info = NodeInfo { pubkey, updated_at: updated_at_std, - online: node.online, + public: node.public, }; if let Some(mut old_node_info) = self.update_node(node.ip, node_info.clone()).await { - if !node_info.online { - old_node_info.online = false; + if !node_info.public { + old_node_info.public = false; } match old_node_info.ne(&node_info) { true => { @@ -178,15 +189,10 @@ impl Store { nodes.insert(ip, info.clone()) } - pub async fn remove_node(&self, ip: &str) { - let mut nodes = self.nodes.lock().await; - nodes.remove(ip); - } - - pub async fn get_pubkey(&self, ip: &str) -> Option { - let nodes = self.nodes.lock().await; - nodes.get(ip).cloned() - } + // pub async fn remove_node(&self, ip: &str) { + // let mut nodes = self.nodes.lock().await; + // nodes.remove(ip); + // } pub async fn get_localhost(&self) -> NodeUpdate { // these unwrap never fail @@ -200,7 +206,7 @@ impl Store { ip: "localhost".to_string(), keypair: hex::encode(keypair.as_bytes()), updated_at: Some(prost_types::Timestamp::from(node_info.updated_at)), - online: false, + public: false, } } @@ -212,13 +218,13 @@ impl Store { let pubkey = keypair_raw.verifying_key(); let ip = "localhost".to_string(); let updated_at = std::time::SystemTime::now(); - let online = false; + let public = false; self.update_node( ip.clone(), NodeInfo { pubkey, updated_at, - online, + public, }, ) .await; @@ -228,7 +234,7 @@ impl Store { ip, keypair, updated_at, - online, + public, } } @@ -242,42 +248,29 @@ impl Store { ip: ip.to_string(), keypair: hex::encode(signing_key.as_bytes()), updated_at: Some(prost_types::Timestamp::from(node_info.updated_at)), - online: node_info.online, + public: node_info.public, }) }) .collect() } - /// you can specify the online argument to get only nodes that are online - pub async fn get_random_nodes(&self, online: bool) -> Vec { + // returns a random node that does not have an active connection + pub async fn get_random_node(&self) -> Option { use rand::rngs::OsRng; use rand::RngCore; let nodes = self.nodes.lock().await; + let conns = self.conns.lock().await; let len = nodes.len(); if len == 0 { - return Vec::new(); + return None; } let skip = OsRng.next_u64().try_into().unwrap_or(0) % len; - let mut iter = nodes.iter().cycle().skip(skip); - let mut random_nodes = vec![]; - let mut count = 0; - let mut iterations = 0; - while count < 3 && iterations < len { - if let Some((ip, info)) = iter.next() { - if online || info.online { - random_nodes.push(ip.clone()); - count -= 1; - } - iterations += 1; - } - } - random_nodes - } - - pub async fn set_online(&self, ip: &str, online: bool) { - let mut nodes = self.nodes.lock().await; - if let Some(node) = nodes.get_mut(ip) { - node.online = online; - } + nodes + .keys() + .cycle() + .skip(skip) + .filter(|k| !conns.contains(*k)) + .next() + .cloned() } } diff --git a/src/grpc/client.rs b/src/grpc/client.rs index 83d58ae..d6b7bbc 100644 --- a/src/grpc/client.rs +++ b/src/grpc/client.rs @@ -1,52 +1,65 @@ -#![allow(dead_code)] use super::challenge::NodeUpdate; use crate::datastore::Store; use crate::grpc::challenge::update_client::UpdateClient; -use std::fs::File; -use std::io::{BufRead, BufReader}; use std::sync::Arc; use tokio::sync::broadcast::Sender; -use tokio::task::JoinSet; +use tokio::time::{sleep, Duration}; use tokio_stream::wrappers::BroadcastStream; use tokio_stream::StreamExt; -struct Connection { - ds: Arc, - tx: Sender, -} - #[derive(Clone)] -struct ConnManager { +pub struct ConnManager { ds: Arc, tx: Sender, } impl ConnManager { - fn init(ds: Arc, tx: Sender) -> Self { + pub fn init(ds: Arc, tx: Sender) -> Self { Self { ds, tx } } - async fn connect(self, node_ip: String) { + pub async fn start_with_node(self, node_ip: String) { + self.connect_wrapper(node_ip).await; + } + + pub async fn start(self) { + loop { + if let Some(node) = self.ds.get_random_node().await { + if node != "localhost" { + self.connect_wrapper(node.clone()).await; + } + } + sleep(Duration::from_secs(3)).await; + } + } + + async fn connect_wrapper(&self, node_ip: String) { + let ds = self.ds.clone(); + ds.add_conn(&node_ip).await; + if let Err(e) = self.connect(node_ip.clone()).await { + println!("Client connection for {node_ip} failed: {e:?}"); + } + ds.delete_conn(&node_ip).await; + } + + async fn connect(&self, node_ip: String) -> Result<(), Box> { println!("Connecting to {node_ip}..."); - let mut client = UpdateClient::connect(format!("http://{node_ip}:31373")) - .await - .unwrap(); + let mut client = UpdateClient::connect(format!("http://{node_ip}:31373")).await?; let rx = self.tx.subscribe(); let rx_stream = BroadcastStream::new(rx).filter_map(|n| n.ok()); - let response = client.get_updates(rx_stream).await.unwrap(); + let response = client.get_updates(rx_stream).await?; let mut resp_stream = response.into_inner(); let _ = self.tx.send(self.ds.get_localhost().await); - while let Some(mut update) = resp_stream.message().await.unwrap() { - println!("Received message"); + while let Some(mut update) = resp_stream.message().await? { // "localhost" IPs need to be changed to the real IP of the counterpart if update.ip == "localhost" { update.ip = node_ip.clone(); // since we are connecting TO this server, we have a guarantee that this - // server is not behind NAT, so we can set it online - update.online = true; + // server is not behind NAT, so we can set it public + update.public = true; } // update the entire network in case the information is new @@ -56,32 +69,23 @@ impl ConnManager { } }; } + + Ok(()) } } -// this must panic on failure; app can't start without init nodes -fn load_init_nodes(path: &str) -> Vec { - let input = File::open(path).unwrap(); - let buffered = BufReader::new(input); - let mut ips = Vec::new(); - for line in buffered.lines() { - ips.push(line.unwrap()); - } - ips -} - -pub async fn init_connections(ds: Arc, tx: Sender) { - let mut nodes = load_init_nodes("detee_challenge_nodes"); - // we rotate online and offline nodes, to constantly check new nodes - let mut only_online_nodes = true; - loop { - let mut set = JoinSet::new(); - for node in nodes { - let conn = ConnManager::init(ds.clone(), tx.clone()); - set.spawn(conn.connect(node)); - } - while let Some(_) = set.join_next().await {} - nodes = ds.get_random_nodes(only_online_nodes).await; - only_online_nodes = !only_online_nodes; - } -} +// pub async fn init_connections(ds: Arc, tx: Sender) { +// let mut nodes = load_init_nodes("detee_challenge_nodes"); +// // we rotate online and offline nodes, to constantly check new nodes +// let mut only_online_nodes = true; +// loop { +// let mut set = JoinSet::new(); +// for node in nodes { +// let conn = ConnManager::init(ds.clone(), tx.clone()); +// set.spawn(conn.connect_wrapper(node)); +// } +// while let Some(_) = set.join_next().await {} +// nodes = ds.get_random_nodes(only_online_nodes).await; +// only_online_nodes = !only_online_nodes; +// } +// } diff --git a/src/grpc/server.rs b/src/grpc/server.rs index 99d95e0..7052394 100644 --- a/src/grpc/server.rs +++ b/src/grpc/server.rs @@ -75,7 +75,6 @@ impl Update for MyServer { } } Ok(update) = rx.recv() => { - println!("Sending message."); yield Ok(update); // disconnect client if too many connections are active if tx.receiver_count() > 9 { diff --git a/src/main.rs b/src/main.rs index 94a5d47..7e914af 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,6 +3,8 @@ use tokio::task::JoinSet; mod grpc; mod http_server; use crate::datastore::Store; +use std::fs::File; +use std::io::{BufRead, BufReader}; use std::sync::Arc; use tokio::sync::broadcast; @@ -13,13 +15,34 @@ async fn main() { ds.reset_localhost_keys().await; - let mut join_set = JoinSet::new(); + let mut long_term_tasks = JoinSet::new(); + let mut init_tasks = JoinSet::new(); - join_set.spawn(http_server::init(ds.clone())); - join_set.spawn(grpc::server::MyServer::init(ds.clone(), tx.clone()).start()); - join_set.spawn(grpc::client::init_connections(ds.clone(), tx.clone())); + long_term_tasks.spawn(http_server::init(ds.clone())); + long_term_tasks.spawn(grpc::server::MyServer::init(ds.clone(), tx.clone()).start()); + + let input = File::open("detee_challenge_nodes").unwrap(); + let buffered = BufReader::new(input); + for line in buffered.lines() { + init_tasks.spawn( + grpc::client::ConnManager::init(ds.clone(), tx.clone()).start_with_node(line.unwrap()), + ); + } + + let mut connection_count = 0; + while init_tasks.join_next().await.is_some() { + if connection_count < 3 { + long_term_tasks.spawn(grpc::client::ConnManager::init(ds.clone(), tx.clone()).start()); + connection_count += 1; + } + } + + while connection_count < 3 { + long_term_tasks.spawn(grpc::client::ConnManager::init(ds.clone(), tx.clone()).start()); + connection_count += 1; + } // exit no matter which task finished - join_set.join_next().await; + long_term_tasks.join_next().await; println!("Shutting down..."); }