#![allow(dead_code)] use super::challenge::{update_server::UpdateServer, Empty, Keys, NodeUpdate}; use crate::{datastore::Store, grpc::challenge::update_server::Update}; use std::{pin::Pin, sync::Arc}; use tokio::sync::broadcast::Sender; use tokio_stream::{Stream, StreamExt}; use tonic::{Request, Response, Status, Streaming}; pub struct MyServer { ds: Arc, tx: Sender, } impl MyServer { pub fn init(ds: Arc, tx: Sender) -> Self { Self { ds, tx } } pub async fn start(self) { use hyper::server::conn::http2::Builder; use hyper_util::{ rt::{TokioExecutor, TokioIo}, service::TowerToHyperService, }; use std::sync::Arc; use tokio::net::TcpListener; use tokio_rustls::{rustls::ServerConfig, TlsAcceptor}; use tonic::{body::boxed, service::Routes}; use tower::{ServiceBuilder, ServiceExt}; use detee_sgx::{prelude::*, RaTlsConfigBuilder}; // TODO: ratls config should be global // TODO: error handling, shouldn't have expects let mrsigner_hex = "83E8A0C3ED045D9747ADE06C3BFC70FCA661A4A65FF79A800223621162A88B76"; let mrsigner = crate::sgx::mrsigner_from_hex(mrsigner_hex).expect("mrsigner decoding failed"); let config = RaTlsConfig::new() .allow_instance_measurement(InstanceMeasurement::new().with_mrsigners(vec![mrsigner])); let mut tls = ServerConfig::from_ratls_config(config) .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, format!("{}", e))) .expect("failed to create server config"); tls.alpn_protocols = vec![b"h2".to_vec()]; let svc = Routes::new(UpdateServer::new(self)).prepare(); let http = Builder::new(TokioExecutor::new()); let listener = TcpListener::bind("0.0.0.0:31373").await.expect("failed to bind listener"); let tls_acceptor = TlsAcceptor::from(Arc::new(tls)); loop { let (conn, addr) = match listener.accept().await { Ok(incoming) => incoming, Err(e) => { eprintln!("Error accepting connection: {}", e); continue; } }; let http = http.clone(); let tls_acceptor = tls_acceptor.clone(); let svc = svc.clone(); tokio::spawn(async move { let mut certificates = Vec::new(); let conn = tls_acceptor .accept_with(conn, |info| { if let Some(certs) = info.peer_certificates() { for cert in certs { certificates.push(cert.clone()); } } }) .await .unwrap(); #[derive(Debug)] pub struct ConnInfo { pub addr: std::net::SocketAddr, pub certificates: Vec>, } let extension_layer = tower_http::add_extension::AddExtensionLayer::new(Arc::new(ConnInfo { addr, certificates, })); let svc = ServiceBuilder::new().layer(extension_layer).service(svc); http.serve_connection( TokioIo::new(conn), TowerToHyperService::new( svc.map_request(|req: hyper::Request<_>| req.map(boxed)), ), ) .await .unwrap(); }); } } } #[tonic::async_trait] impl Update for MyServer { type GetUpdatesStream = Pin> + Send>>; async fn get_updates( &self, req: Request>, ) -> Result, Status> { self.ds.increase_mratls_conns(); let remote_ip = req.remote_addr().unwrap().ip().to_string(); let tx = self.tx.clone(); let mut rx = self.tx.subscribe(); let mut inbound = req.into_inner(); let ds = self.ds.clone(); let stream = async_stream::stream! { let full_update_list: Vec = ds.get_node_list().into_iter().map(Into::::into).collect(); for update in full_update_list { yield Ok(update); } loop { tokio::select! { Some(msg) = inbound.next() => { match msg { Ok(mut update) => { if update.ip == "localhost" { update.ip = remote_ip.clone(); // note that we don't set this node online, // as it can be behind NAT } if update.ip != "127.0.0.1" && ds.process_node_update(update.clone().into()) && tx.send(update.clone()).is_err() { println!("tokio broadcast receivers had an issue consuming the channel"); }; } Err(e) => { yield Err(Status::internal(format!("Error receiving client stream: {}", e))); break; } } } Ok(update) = rx.recv() => { yield Ok(update); // disconnect client if too many connections are active if tx.receiver_count() > 9 { yield Err(Status::internal("Already have too many clients. Connect to another server.")); return; } } } } }; Ok(Response::new(Box::pin(stream) as Self::GetUpdatesStream)) } async fn get_keys(&self, _request: Request) -> Result, Status> { let reply = Keys { keypair: self.ds.get_keypair_bytes(), token_address: self.ds.get_token_address(), }; Ok(Response::new(reply)) } }