fixed connection management
all nodes now properly maintain 3 outbound connections
This commit is contained in:
		
							parent
							
								
									efa36737a0
								
							
						
					
					
						commit
						4d00b116f4
					
				| @ -7,7 +7,7 @@ message NodeUpdate { | |||||||
|    string ip = 1; |    string ip = 1; | ||||||
|    string keypair = 2; |    string keypair = 2; | ||||||
|    google.protobuf.Timestamp updated_at = 3; |    google.protobuf.Timestamp updated_at = 3; | ||||||
|    bool online = 4; |    bool public = 4; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| service Update { | service Update { | ||||||
|  | |||||||
| @ -1,8 +1,7 @@ | |||||||
| #![allow(dead_code)] |  | ||||||
| use crate::grpc::challenge::NodeUpdate; | use crate::grpc::challenge::NodeUpdate; | ||||||
| use ed25519_dalek::{Signer, SigningKey, VerifyingKey}; | use ed25519_dalek::{Signer, SigningKey, VerifyingKey}; | ||||||
| use rand::rngs::OsRng; | use rand::rngs::OsRng; | ||||||
| use std::collections::HashMap; | use std::collections::{HashMap, HashSet}; | ||||||
| use std::time::Duration; | use std::time::Duration; | ||||||
| use std::time::SystemTime; | use std::time::SystemTime; | ||||||
| use std::time::UNIX_EPOCH; | use std::time::UNIX_EPOCH; | ||||||
| @ -13,12 +12,13 @@ use tokio::sync::Mutex; | |||||||
| pub struct NodeInfo { | pub struct NodeInfo { | ||||||
|     pub pubkey: VerifyingKey, |     pub pubkey: VerifyingKey, | ||||||
|     pub updated_at: SystemTime, |     pub updated_at: SystemTime, | ||||||
|     pub online: bool, |     pub public: bool, | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /// Needs to be surrounded in an Arc.
 | /// Needs to be surrounded in an Arc.
 | ||||||
| pub struct Store { | pub struct Store { | ||||||
|     nodes: Mutex<HashMap<IP, NodeInfo>>, |     nodes: Mutex<HashMap<IP, NodeInfo>>, | ||||||
|  |     conns: Mutex<HashSet<IP>>, | ||||||
|     keys: Mutex<HashMap<VerifyingKey, SigningKey>>, |     keys: Mutex<HashMap<VerifyingKey, SigningKey>>, | ||||||
| } | } | ||||||
| pub enum SigningError { | pub enum SigningError { | ||||||
| @ -57,20 +57,32 @@ impl std::fmt::Display for SigningError { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| impl Store { | impl Store { | ||||||
|  |     // app should exit if any error happens here so unwrap() is good
 | ||||||
|     pub fn init() -> Self { |     pub fn init() -> Self { | ||||||
|         Self { |         Self { | ||||||
|             nodes: Mutex::new(HashMap::new()), |             nodes: Mutex::new(HashMap::new()), | ||||||
|             keys: Mutex::new(HashMap::new()), |             keys: Mutex::new(HashMap::new()), | ||||||
|  |             conns: Mutex::new(HashSet::new()), | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |     pub async fn add_conn(&self, ip: &str) { | ||||||
|  |         let mut conns = self.conns.lock().await; | ||||||
|  |         conns.insert(ip.to_string()); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub async fn delete_conn(&self, ip: &str) { | ||||||
|  |         let mut conns = self.conns.lock().await; | ||||||
|  |         conns.remove(ip); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     pub async fn tabled_node_list(&self) -> String { |     pub async fn tabled_node_list(&self) -> String { | ||||||
|         #[derive(Tabled)] |         #[derive(Tabled)] | ||||||
|         struct OutputRow { |         struct OutputRow { | ||||||
|             ip: String, |             ip: String, | ||||||
|             pubkey: String, |             pubkey: String, | ||||||
|             age: u64, |             age: u64, | ||||||
|             online: bool, |             public: bool, | ||||||
|         } |         } | ||||||
|         let mut output = vec![]; |         let mut output = vec![]; | ||||||
|         for (ip, node_info) in self.nodes.lock().await.iter() { |         for (ip, node_info) in self.nodes.lock().await.iter() { | ||||||
| @ -80,12 +92,12 @@ impl Store { | |||||||
|                 .duration_since(node_info.updated_at) |                 .duration_since(node_info.updated_at) | ||||||
|                 .unwrap_or(Duration::ZERO) |                 .unwrap_or(Duration::ZERO) | ||||||
|                 .as_secs(); |                 .as_secs(); | ||||||
|             let online = node_info.online; |             let public = node_info.public; | ||||||
|             output.push(OutputRow { |             output.push(OutputRow { | ||||||
|                 ip, |                 ip, | ||||||
|                 pubkey, |                 pubkey, | ||||||
|                 age, |                 age, | ||||||
|                 online, |                 public, | ||||||
|             }); |             }); | ||||||
|         } |         } | ||||||
|         Table::new(output).to_string() |         Table::new(output).to_string() | ||||||
| @ -99,8 +111,7 @@ impl Store { | |||||||
|         let key_bytes = hex::decode(pubkey)?; |         let key_bytes = hex::decode(pubkey)?; | ||||||
|         let pubkey = VerifyingKey::from_bytes(&key_bytes.as_slice().try_into()?)?; |         let pubkey = VerifyingKey::from_bytes(&key_bytes.as_slice().try_into()?)?; | ||||||
| 
 | 
 | ||||||
|         let key_store = self.keys.lock().await; |         let signing_key = match self.get_privkey(&pubkey).await { | ||||||
|         let signing_key = match { key_store.get(&pubkey) } { |  | ||||||
|             Some(k) => k, |             Some(k) => k, | ||||||
|             None => return Err(SigningError::KeyNotFound), |             None => return Err(SigningError::KeyNotFound), | ||||||
|         }; |         }; | ||||||
| @ -154,11 +165,11 @@ impl Store { | |||||||
|         let node_info = NodeInfo { |         let node_info = NodeInfo { | ||||||
|             pubkey, |             pubkey, | ||||||
|             updated_at: updated_at_std, |             updated_at: updated_at_std, | ||||||
|             online: node.online, |             public: node.public, | ||||||
|         }; |         }; | ||||||
|         if let Some(mut old_node_info) = self.update_node(node.ip, node_info.clone()).await { |         if let Some(mut old_node_info) = self.update_node(node.ip, node_info.clone()).await { | ||||||
|             if !node_info.online { |             if !node_info.public { | ||||||
|                 old_node_info.online = false; |                 old_node_info.public = false; | ||||||
|             } |             } | ||||||
|             match old_node_info.ne(&node_info) { |             match old_node_info.ne(&node_info) { | ||||||
|                 true => { |                 true => { | ||||||
| @ -178,15 +189,10 @@ impl Store { | |||||||
|         nodes.insert(ip, info.clone()) |         nodes.insert(ip, info.clone()) | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     pub async fn remove_node(&self, ip: &str) { |     //    pub async fn remove_node(&self, ip: &str) {
 | ||||||
|         let mut nodes = self.nodes.lock().await; |     //        let mut nodes = self.nodes.lock().await;
 | ||||||
|         nodes.remove(ip); |     //        nodes.remove(ip);
 | ||||||
|     } |     //    }
 | ||||||
| 
 |  | ||||||
|     pub async fn get_pubkey(&self, ip: &str) -> Option<NodeInfo> { |  | ||||||
|         let nodes = self.nodes.lock().await; |  | ||||||
|         nodes.get(ip).cloned() |  | ||||||
|     } |  | ||||||
| 
 | 
 | ||||||
|     pub async fn get_localhost(&self) -> NodeUpdate { |     pub async fn get_localhost(&self) -> NodeUpdate { | ||||||
|         // these unwrap never fail
 |         // these unwrap never fail
 | ||||||
| @ -200,7 +206,7 @@ impl Store { | |||||||
|             ip: "localhost".to_string(), |             ip: "localhost".to_string(), | ||||||
|             keypair: hex::encode(keypair.as_bytes()), |             keypair: hex::encode(keypair.as_bytes()), | ||||||
|             updated_at: Some(prost_types::Timestamp::from(node_info.updated_at)), |             updated_at: Some(prost_types::Timestamp::from(node_info.updated_at)), | ||||||
|             online: false, |             public: false, | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| @ -212,13 +218,13 @@ impl Store { | |||||||
|         let pubkey = keypair_raw.verifying_key(); |         let pubkey = keypair_raw.verifying_key(); | ||||||
|         let ip = "localhost".to_string(); |         let ip = "localhost".to_string(); | ||||||
|         let updated_at = std::time::SystemTime::now(); |         let updated_at = std::time::SystemTime::now(); | ||||||
|         let online = false; |         let public = false; | ||||||
|         self.update_node( |         self.update_node( | ||||||
|             ip.clone(), |             ip.clone(), | ||||||
|             NodeInfo { |             NodeInfo { | ||||||
|                 pubkey, |                 pubkey, | ||||||
|                 updated_at, |                 updated_at, | ||||||
|                 online, |                 public, | ||||||
|             }, |             }, | ||||||
|         ) |         ) | ||||||
|         .await; |         .await; | ||||||
| @ -228,7 +234,7 @@ impl Store { | |||||||
|             ip, |             ip, | ||||||
|             keypair, |             keypair, | ||||||
|             updated_at, |             updated_at, | ||||||
|             online, |             public, | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| @ -242,42 +248,29 @@ impl Store { | |||||||
|                     ip: ip.to_string(), |                     ip: ip.to_string(), | ||||||
|                     keypair: hex::encode(signing_key.as_bytes()), |                     keypair: hex::encode(signing_key.as_bytes()), | ||||||
|                     updated_at: Some(prost_types::Timestamp::from(node_info.updated_at)), |                     updated_at: Some(prost_types::Timestamp::from(node_info.updated_at)), | ||||||
|                     online: node_info.online, |                     public: node_info.public, | ||||||
|                 }) |                 }) | ||||||
|             }) |             }) | ||||||
|             .collect() |             .collect() | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /// you can specify the online argument to get only nodes that are online
 |     // returns a random node that does not have an active connection
 | ||||||
|     pub async fn get_random_nodes(&self, online: bool) -> Vec<String> { |     pub async fn get_random_node(&self) -> Option<String> { | ||||||
|         use rand::rngs::OsRng; |         use rand::rngs::OsRng; | ||||||
|         use rand::RngCore; |         use rand::RngCore; | ||||||
|         let nodes = self.nodes.lock().await; |         let nodes = self.nodes.lock().await; | ||||||
|  |         let conns = self.conns.lock().await; | ||||||
|         let len = nodes.len(); |         let len = nodes.len(); | ||||||
|         if len == 0 { |         if len == 0 { | ||||||
|             return Vec::new(); |             return None; | ||||||
|         } |         } | ||||||
|         let skip = OsRng.next_u64().try_into().unwrap_or(0) % len; |         let skip = OsRng.next_u64().try_into().unwrap_or(0) % len; | ||||||
|         let mut iter = nodes.iter().cycle().skip(skip); |         nodes | ||||||
|         let mut random_nodes = vec![]; |             .keys() | ||||||
|         let mut count = 0; |             .cycle() | ||||||
|         let mut iterations = 0; |             .skip(skip) | ||||||
|         while count < 3 && iterations < len { |             .filter(|k| !conns.contains(*k)) | ||||||
|             if let Some((ip, info)) = iter.next() { |             .next() | ||||||
|                 if online || info.online { |             .cloned() | ||||||
|                     random_nodes.push(ip.clone()); |  | ||||||
|                     count -= 1; |  | ||||||
|                 } |  | ||||||
|                 iterations += 1; |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|         random_nodes |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     pub async fn set_online(&self, ip: &str, online: bool) { |  | ||||||
|         let mut nodes = self.nodes.lock().await; |  | ||||||
|         if let Some(node) = nodes.get_mut(ip) { |  | ||||||
|             node.online = online; |  | ||||||
|         } |  | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  | |||||||
| @ -1,52 +1,65 @@ | |||||||
| #![allow(dead_code)] |  | ||||||
| use super::challenge::NodeUpdate; | use super::challenge::NodeUpdate; | ||||||
| use crate::datastore::Store; | use crate::datastore::Store; | ||||||
| use crate::grpc::challenge::update_client::UpdateClient; | use crate::grpc::challenge::update_client::UpdateClient; | ||||||
| use std::fs::File; |  | ||||||
| use std::io::{BufRead, BufReader}; |  | ||||||
| use std::sync::Arc; | use std::sync::Arc; | ||||||
| use tokio::sync::broadcast::Sender; | use tokio::sync::broadcast::Sender; | ||||||
| use tokio::task::JoinSet; | use tokio::time::{sleep, Duration}; | ||||||
| use tokio_stream::wrappers::BroadcastStream; | use tokio_stream::wrappers::BroadcastStream; | ||||||
| use tokio_stream::StreamExt; | use tokio_stream::StreamExt; | ||||||
| 
 | 
 | ||||||
| struct Connection { |  | ||||||
|     ds: Arc<Store>, |  | ||||||
|     tx: Sender<NodeUpdate>, |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| #[derive(Clone)] | #[derive(Clone)] | ||||||
| struct ConnManager { | pub struct ConnManager { | ||||||
|     ds: Arc<Store>, |     ds: Arc<Store>, | ||||||
|     tx: Sender<NodeUpdate>, |     tx: Sender<NodeUpdate>, | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| impl ConnManager { | impl ConnManager { | ||||||
|     fn init(ds: Arc<Store>, tx: Sender<NodeUpdate>) -> Self { |     pub fn init(ds: Arc<Store>, tx: Sender<NodeUpdate>) -> Self { | ||||||
|         Self { ds, tx } |         Self { ds, tx } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     async fn connect(self, node_ip: String) { |     pub async fn start_with_node(self, node_ip: String) { | ||||||
|  |         self.connect_wrapper(node_ip).await; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub async fn start(self) { | ||||||
|  |         loop { | ||||||
|  |             if let Some(node) = self.ds.get_random_node().await { | ||||||
|  |                 if node != "localhost" { | ||||||
|  |                     self.connect_wrapper(node.clone()).await; | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |             sleep(Duration::from_secs(3)).await; | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     async fn connect_wrapper(&self, node_ip: String) { | ||||||
|  |         let ds = self.ds.clone(); | ||||||
|  |         ds.add_conn(&node_ip).await; | ||||||
|  |         if let Err(e) = self.connect(node_ip.clone()).await { | ||||||
|  |             println!("Client connection for {node_ip} failed: {e:?}"); | ||||||
|  |         } | ||||||
|  |         ds.delete_conn(&node_ip).await; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     async fn connect(&self, node_ip: String) -> Result<(), Box<dyn std::error::Error>> { | ||||||
|         println!("Connecting to {node_ip}..."); |         println!("Connecting to {node_ip}..."); | ||||||
|         let mut client = UpdateClient::connect(format!("http://{node_ip}:31373")) |         let mut client = UpdateClient::connect(format!("http://{node_ip}:31373")).await?; | ||||||
|             .await |  | ||||||
|             .unwrap(); |  | ||||||
| 
 | 
 | ||||||
|         let rx = self.tx.subscribe(); |         let rx = self.tx.subscribe(); | ||||||
|         let rx_stream = BroadcastStream::new(rx).filter_map(|n| n.ok()); |         let rx_stream = BroadcastStream::new(rx).filter_map(|n| n.ok()); | ||||||
|         let response = client.get_updates(rx_stream).await.unwrap(); |         let response = client.get_updates(rx_stream).await?; | ||||||
|         let mut resp_stream = response.into_inner(); |         let mut resp_stream = response.into_inner(); | ||||||
| 
 | 
 | ||||||
|         let _ = self.tx.send(self.ds.get_localhost().await); |         let _ = self.tx.send(self.ds.get_localhost().await); | ||||||
| 
 | 
 | ||||||
|         while let Some(mut update) = resp_stream.message().await.unwrap() { |         while let Some(mut update) = resp_stream.message().await? { | ||||||
|             println!("Received message"); |  | ||||||
|             // "localhost" IPs need to be changed to the real IP of the counterpart
 |             // "localhost" IPs need to be changed to the real IP of the counterpart
 | ||||||
|             if update.ip == "localhost" { |             if update.ip == "localhost" { | ||||||
|                 update.ip = node_ip.clone(); |                 update.ip = node_ip.clone(); | ||||||
|                 // since we are connecting TO this server, we have a guarantee that this
 |                 // since we are connecting TO this server, we have a guarantee that this
 | ||||||
|                 // server is not behind NAT, so we can set it online
 |                 // server is not behind NAT, so we can set it public
 | ||||||
|                 update.online = true; |                 update.public = true; | ||||||
|             } |             } | ||||||
| 
 | 
 | ||||||
|             // update the entire network in case the information is new
 |             // update the entire network in case the information is new
 | ||||||
| @ -56,32 +69,23 @@ impl ConnManager { | |||||||
|                 } |                 } | ||||||
|             }; |             }; | ||||||
|         } |         } | ||||||
|  | 
 | ||||||
|  |         Ok(()) | ||||||
|     } |     } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // this must panic on failure; app can't start without init nodes
 | // pub async fn init_connections(ds: Arc<Store>, tx: Sender<NodeUpdate>) {
 | ||||||
| fn load_init_nodes(path: &str) -> Vec<String> { | //     let mut nodes = load_init_nodes("detee_challenge_nodes");
 | ||||||
|     let input = File::open(path).unwrap(); | //     // we rotate online and offline nodes, to constantly check new nodes
 | ||||||
|     let buffered = BufReader::new(input); | //     let mut only_online_nodes = true;
 | ||||||
|     let mut ips = Vec::new(); | //     loop {
 | ||||||
|     for line in buffered.lines() { | //         let mut set = JoinSet::new();
 | ||||||
|         ips.push(line.unwrap()); | //         for node in nodes {
 | ||||||
|     } | //             let conn = ConnManager::init(ds.clone(), tx.clone());
 | ||||||
|     ips | //             set.spawn(conn.connect_wrapper(node));
 | ||||||
| } | //         }
 | ||||||
| 
 | //         while let Some(_) = set.join_next().await {}
 | ||||||
| pub async fn init_connections(ds: Arc<Store>, tx: Sender<NodeUpdate>) { | //         nodes = ds.get_random_nodes(only_online_nodes).await;
 | ||||||
|     let mut nodes = load_init_nodes("detee_challenge_nodes"); | //         only_online_nodes = !only_online_nodes;
 | ||||||
|     // we rotate online and offline nodes, to constantly check new nodes
 | //     }
 | ||||||
|     let mut only_online_nodes = true; | // }
 | ||||||
|     loop { |  | ||||||
|         let mut set = JoinSet::new(); |  | ||||||
|         for node in nodes { |  | ||||||
|             let conn = ConnManager::init(ds.clone(), tx.clone()); |  | ||||||
|             set.spawn(conn.connect(node)); |  | ||||||
|         } |  | ||||||
|         while let Some(_) = set.join_next().await {} |  | ||||||
|         nodes = ds.get_random_nodes(only_online_nodes).await; |  | ||||||
|         only_online_nodes = !only_online_nodes; |  | ||||||
|     } |  | ||||||
| } |  | ||||||
|  | |||||||
| @ -75,7 +75,6 @@ impl Update for MyServer { | |||||||
|                         } |                         } | ||||||
|                     } |                     } | ||||||
|                     Ok(update) = rx.recv() => { |                     Ok(update) = rx.recv() => { | ||||||
|                         println!("Sending message."); |  | ||||||
|                         yield Ok(update); |                         yield Ok(update); | ||||||
|                         // disconnect client if too many connections are active
 |                         // disconnect client if too many connections are active
 | ||||||
|                         if tx.receiver_count() > 9 { |                         if tx.receiver_count() > 9 { | ||||||
|  | |||||||
							
								
								
									
										33
									
								
								src/main.rs
									
									
									
									
									
								
							
							
								
								
								
								
								
									
									
								
							
						
						
									
										33
									
								
								src/main.rs
									
									
									
									
									
								
							| @ -3,6 +3,8 @@ use tokio::task::JoinSet; | |||||||
| mod grpc; | mod grpc; | ||||||
| mod http_server; | mod http_server; | ||||||
| use crate::datastore::Store; | use crate::datastore::Store; | ||||||
|  | use std::fs::File; | ||||||
|  | use std::io::{BufRead, BufReader}; | ||||||
| use std::sync::Arc; | use std::sync::Arc; | ||||||
| use tokio::sync::broadcast; | use tokio::sync::broadcast; | ||||||
| 
 | 
 | ||||||
| @ -13,13 +15,34 @@ async fn main() { | |||||||
| 
 | 
 | ||||||
|     ds.reset_localhost_keys().await; |     ds.reset_localhost_keys().await; | ||||||
| 
 | 
 | ||||||
|     let mut join_set = JoinSet::new(); |     let mut long_term_tasks = JoinSet::new(); | ||||||
|  |     let mut init_tasks = JoinSet::new(); | ||||||
| 
 | 
 | ||||||
|     join_set.spawn(http_server::init(ds.clone())); |     long_term_tasks.spawn(http_server::init(ds.clone())); | ||||||
|     join_set.spawn(grpc::server::MyServer::init(ds.clone(), tx.clone()).start()); |     long_term_tasks.spawn(grpc::server::MyServer::init(ds.clone(), tx.clone()).start()); | ||||||
|     join_set.spawn(grpc::client::init_connections(ds.clone(), tx.clone())); | 
 | ||||||
|  |     let input = File::open("detee_challenge_nodes").unwrap(); | ||||||
|  |     let buffered = BufReader::new(input); | ||||||
|  |     for line in buffered.lines() { | ||||||
|  |         init_tasks.spawn( | ||||||
|  |             grpc::client::ConnManager::init(ds.clone(), tx.clone()).start_with_node(line.unwrap()), | ||||||
|  |         ); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     let mut connection_count = 0; | ||||||
|  |     while init_tasks.join_next().await.is_some() { | ||||||
|  |         if connection_count < 3 { | ||||||
|  |             long_term_tasks.spawn(grpc::client::ConnManager::init(ds.clone(), tx.clone()).start()); | ||||||
|  |             connection_count += 1; | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     while connection_count < 3 { | ||||||
|  |         long_term_tasks.spawn(grpc::client::ConnManager::init(ds.clone(), tx.clone()).start()); | ||||||
|  |         connection_count += 1; | ||||||
|  |     } | ||||||
| 
 | 
 | ||||||
|     // exit no matter which task finished
 |     // exit no matter which task finished
 | ||||||
|     join_set.join_next().await; |     long_term_tasks.join_next().await; | ||||||
|     println!("Shutting down..."); |     println!("Shutting down..."); | ||||||
| } | } | ||||||
|  | |||||||
		Loading…
	
		Reference in New Issue
	
	Block a user