use super::challenge::{update_server::UpdateServer, Empty, Keys, NodeUpdate}; use crate::{datastore::State, grpc::challenge::update_server::Update}; use detee_sgx::RaTlsConfig; use rustls::pki_types::CertificateDer; use std::{pin::Pin, sync::Arc}; use tokio::sync::broadcast::Sender; use tokio_stream::{Stream, StreamExt}; use tonic::{Request, Response, Status, Streaming}; pub struct MyServer { state: Arc, tx: Sender, ratls_config: RaTlsConfig, keys: Keys, // For sending secret keys to new nodes ;) } impl MyServer { pub fn init( state: Arc, keys: Keys, ratls_config: RaTlsConfig, tx: Sender, ) -> Self { Self { state, tx, keys, ratls_config } } pub async fn start(self) { use detee_sgx::RaTlsConfigBuilder; use hyper::server::conn::http2::Builder; use hyper_util::{ rt::{TokioExecutor, TokioIo}, service::TowerToHyperService, }; use std::sync::Arc; use tokio::net::TcpListener; use tokio_rustls::{rustls::ServerConfig, TlsAcceptor}; use tonic::{body::boxed, service::Routes}; use tower::{ServiceBuilder, ServiceExt}; // TODO: error handling, shouldn't have expects let mut tls = ServerConfig::from_ratls_config(self.ratls_config.clone()).unwrap(); tls.alpn_protocols = vec![b"h2".to_vec()]; let state = self.state.clone(); let svc = Routes::new(UpdateServer::new(self)).prepare(); let http = Builder::new(TokioExecutor::new()); let listener = TcpListener::bind("0.0.0.0:31373").await.expect("failed to bind listener"); let tls_acceptor = TlsAcceptor::from(Arc::new(tls)); loop { let (conn, addr) = match listener.accept().await { Ok(incoming) => incoming, Err(e) => { println!("Error accepting connection: {}", e); state.increase_net_attacks(); continue; } }; let http = http.clone(); let tls_acceptor = tls_acceptor.clone(); let svc = svc.clone(); let state = state.clone(); tokio::spawn(async move { let mut certificates = Vec::new(); let conn = tls_acceptor .accept_with(conn, |info| { if let Some(certs) = info.peer_certificates() { for cert in certs { certificates.push(cert.clone()); } } }) .await; let conn = if let Err(e) = conn { println!("Error accepting TLS connection: {}", e); state.increase_net_attacks(); return; } else { conn.unwrap() }; let extension_layer = tower_http::add_extension::AddExtensionLayer::new(Arc::new(ConnInfo { addr, certificates, })); let svc = ServiceBuilder::new().layer(extension_layer).service(svc); if let Err(e) = http .serve_connection( TokioIo::new(conn), TowerToHyperService::new( svc.map_request(|req: hyper::Request<_>| req.map(boxed)), ), ) .await { println!("Error serving connection: {}", e); } }); } } } #[derive(Debug)] struct ConnInfo { addr: std::net::SocketAddr, certificates: Vec>, } #[tonic::async_trait] impl Update for MyServer { type GetUpdatesStream = Pin> + Send>>; async fn get_keys(&self, _request: Request) -> Result, Status> { Ok(Response::new(self.keys.clone())) } async fn get_updates( &self, req: Request>, ) -> Result, Status> { let conn_info = req.extensions().get::>().unwrap(); self.state.increase_mratls_conns(); let remote_ip = conn_info.addr.ip().to_string(); let tx = self.tx.clone(); let mut rx = self.tx.subscribe(); let mut inbound = req.into_inner(); let ds = self.state.clone(); let stream = async_stream::stream! { let full_update_list: Vec = ds.get_node_list().into_iter().map(Into::::into).collect(); for update in full_update_list { yield Ok(update); } loop { tokio::select! { Some(msg) = inbound.next() => { match msg { Ok(mut update) => { if update.ip == "localhost" { update.ip = remote_ip.clone(); // note that we don't set this node online, // as it can be behind NAT } if update.ip != "127.0.0.1" && ds.process_node_update(update.clone().into()) && tx.send(update.clone()).is_err() { println!("tokio broadcast receivers had an issue consuming the channel"); }; } Err(e) => { yield Err(Status::internal(format!("Error receiving client stream: {}", e))); break; } } } Ok(update) = rx.recv() => { yield Ok(update); // disconnect client if too many connections are active if tx.receiver_count() > 9 { yield Err(Status::internal("Already have too many clients. Connect to another server.")); return; } } } } }; Ok(Response::new(Box::pin(stream) as Self::GetUpdatesStream)) } }