use super::{ challenge::{update_server::UpdateServer, Empty, Keys, NodeUpdate}, InternalNodeUpdate, }; use crate::{datastore::State, grpc::challenge::update_server::Update}; use detee_sgx::RaTlsConfig; use rustls::pki_types::CertificateDer; use std::{pin::Pin, sync::Arc}; use tokio::{sync::broadcast::Sender, time::interval}; use tokio_stream::{Stream, StreamExt}; use tonic::{Request, Response, Status, Streaming}; pub async fn grpc_new_server( state: Arc, keys: Keys, ra_cfg: RaTlsConfig, tx: Sender, ) { NodeServer::init(state, keys, ra_cfg, tx).start().await } pub struct NodeServer { state: Arc, tx: Sender, ra_cfg: RaTlsConfig, keys: Keys, // For sending secret keys to new nodes ;) } impl NodeServer { pub fn init( state: Arc, keys: Keys, ra_cfg: RaTlsConfig, tx: Sender, ) -> Self { Self { state, tx, keys, ra_cfg } } pub async fn start(self) { use detee_sgx::RaTlsConfigBuilder; use hyper::server::conn::http2::Builder; use hyper_util::{ rt::{TokioExecutor, TokioIo}, service::TowerToHyperService, }; use std::sync::Arc; use tokio::net::TcpListener; use tokio_rustls::{rustls::ServerConfig, TlsAcceptor}; use tonic::{body::boxed, service::Routes}; use tower::{ServiceBuilder, ServiceExt}; // TODO: error handling, shouldn't have expects let mut tls = ServerConfig::from_ratls_config(self.ra_cfg.clone()).unwrap(); tls.alpn_protocols = vec![b"h2".to_vec()]; let state = self.state.clone(); let svc = Routes::new(UpdateServer::new(self)).prepare(); let http = Builder::new(TokioExecutor::new()); let listener = TcpListener::bind("0.0.0.0:31373").await.expect("failed to bind listener"); let tls_acceptor = TlsAcceptor::from(Arc::new(tls)); loop { let (conn, addr) = match listener.accept().await { Ok(incoming) => incoming, Err(e) => { println!("Error accepting connection: {}", e); continue; } }; let http = http.clone(); let tls_acceptor = tls_acceptor.clone(); let svc = svc.clone(); state.declare_myself_public().await; let state = state.clone(); tokio::spawn(async move { let mut certificates = Vec::new(); let conn = tls_acceptor .accept_with(conn, |info| { if let Some(certs) = info.peer_certificates() { for cert in certs { certificates.push(cert.clone()); } } }) .await; let conn = if let Err(e) = conn { println!("Error accepting TLS connection: {e}"); let attack_error_messages = ["handshake", "certificate", "quote"]; let err_str = e.to_string().to_lowercase(); if attack_error_messages.iter().any(|att_er_str| err_str.contains(att_er_str)) { state.increase_net_attacks().await; } return; } else { conn.unwrap() }; let extension_layer = tower_http::add_extension::AddExtensionLayer::new(Arc::new(ConnInfo { addr, certificates, })); let svc = ServiceBuilder::new().layer(extension_layer).service(svc); let ip = addr.ip().to_string(); state.add_conn(&ip).await; if let Err(e) = http .serve_connection( TokioIo::new(conn), TowerToHyperService::new( svc.map_request(|req: hyper::Request<_>| req.map(boxed)), ), ) .await { println!("Error serving connection: {}", e); } state.delete_conn(&ip).await; }); } } } #[derive(Debug)] struct ConnInfo { addr: std::net::SocketAddr, #[allow(dead_code)] certificates: Vec>, } #[tonic::async_trait] impl Update for NodeServer { type GetUpdatesStream = Pin> + Send>>; async fn get_keys(&self, _request: Request) -> Result, Status> { Ok(Response::new(self.keys.clone())) } async fn get_updates( &self, req: Request>, ) -> Result, Status> { // connection info is added in the tower for tls connection and must be present let conn_info = req.extensions().get::>().unwrap(); let remote_ip = conn_info.addr.ip().to_string(); let tx = self.tx.clone(); let mut rx = self.tx.subscribe(); let mut inbound = req.into_inner(); let state = self.state.clone(); let my_ip = self.state.get_my_ip().await; let stream = async_stream::stream! { let known_nodes: Vec = state.get_nodes().await.into_iter().map(Into::into).collect(); for update in known_nodes { yield Ok(update); } let error_status: Status; // Gets initialized inside loop let mut timeout = interval(state.get_timeout()); timeout.tick().await; loop { tokio::select! { Some(msg) = inbound.next() => { timeout.reset(); match msg { Ok(update) => { if update.ip == remote_ip { println!("Node {remote_ip} is sending it's own update"); } else if update.ip == my_ip { println!("Node {remote_ip} is forwarding our past update"); } else { println!("Node {remote_ip} is forwarding the update of {}", update.ip); } if state.process_node_update(update.clone().into()).await { // If process update returns true, the update must be forwarded if tx.send((remote_ip.clone(), update).into()).is_err() { println!("Tokio broadcast receivers had an issue consuming the channel"); }; } } Err(e) => { error_status = Status::internal(format!("Error receiving client stream: {}", e)); break; } } } Ok(update) = rx.recv() => { if update.sender_ip != remote_ip { // Don't bounce back the update we just received yield Ok(update.update); } // TODO: check if disconnect client if too many connections are active // Its tested and working if tx.receiver_count() > 9 { error_status = Status::internal("Already have too many clients. Connect to another server."); break; } } _ = timeout.tick() => { error_status = Status::internal(format!("Disconnecting after {}s timeout", state.get_timeout().as_secs())); break; } } } yield Err(error_status); }; Ok(Response::new(Box::pin(stream) as Self::GetUpdatesStream)) } }