hacker-challenge-sgx-general/src/grpc/server.rs
Valentyn Faychuk 3c93b258f5
write metrics to the /host/logs
Signed-off-by: Valentyn Faychuk <valy@detee.ltd>
2024-12-02 03:42:38 +02:00

173 lines
6.3 KiB
Rust

#![allow(dead_code)]
use super::challenge::{update_server::UpdateServer, Empty, Keys, NodeUpdate};
use crate::{datastore::Store, grpc::challenge::update_server::Update};
use std::{pin::Pin, sync::Arc};
use tokio::sync::broadcast::Sender;
use tokio_stream::{Stream, StreamExt};
use tonic::{Request, Response, Status, Streaming};
/// gRPC server state backing the `Update` service implementation in this file.
pub struct MyServer {
/// Shared datastore handle; the handlers read the node list and keys from it
/// and feed incoming node updates into it.
ds: Arc<Store>,
/// Broadcast sender for node updates: each `get_updates` call subscribes to it
/// and publishes the client updates accepted by the datastore.
tx: Sender<NodeUpdate>,
}
impl MyServer {
    /// Builds a server from the shared datastore handle and the broadcast sender
    /// used to publish node updates.
    pub fn init(ds: Arc<Store>, tx: Sender<NodeUpdate>) -> Self {
        Self { ds, tx }
    }

    /// Serves the `Update` gRPC service over HTTP/2 + TLS on `0.0.0.0:31373`.
    ///
    /// The TLS configuration is derived from a `RaTlsConfig` that allows instance
    /// measurements signed by the hard-coded MRSIGNER below. Every accepted TCP
    /// connection is handled in its own task; a failing handshake or connection
    /// is logged and only ends that task. This function never returns.
    ///
    /// # Panics
    /// Panics during start-up if the MRSIGNER cannot be decoded, the TLS server
    /// configuration cannot be created, or the listener cannot be bound.
    pub async fn start(self) {
        use detee_sgx::{prelude::*, RaTlsConfigBuilder};
        use hyper::server::conn::http2::Builder;
        use hyper_util::{
            rt::{TokioExecutor, TokioIo},
            service::TowerToHyperService,
        };
        use std::sync::Arc;
        use tokio::net::TcpListener;
        use tokio_rustls::{rustls::ServerConfig, TlsAcceptor};
        use tonic::{body::boxed, service::Routes};
        use tower::{ServiceBuilder, ServiceExt};

        // TODO: ratls config should be global
        // TODO: error handling, shouldn't have expects
        let mrsigner_hex = "83E8A0C3ED045D9747ADE06C3BFC70FCA661A4A65FF79A800223621162A88B76";
        let mrsigner =
            crate::sgx::mrsigner_from_hex(mrsigner_hex).expect("mrsigner decoding failed");
        let config = RaTlsConfig::new()
            .allow_instance_measurement(InstanceMeasurement::new().with_mrsigners(vec![mrsigner]));
        let mut tls = ServerConfig::from_ratls_config(config)
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, format!("{}", e)))
            .expect("failed to create server config");
        // gRPC runs over HTTP/2, so only "h2" is offered via ALPN.
        tls.alpn_protocols = vec![b"h2".to_vec()];

        let svc = Routes::new(UpdateServer::new(self)).prepare();
        let http = Builder::new(TokioExecutor::new());
        let listener = TcpListener::bind("0.0.0.0:31373").await.expect("failed to bind listener");
        let tls_acceptor = TlsAcceptor::from(Arc::new(tls));

        loop {
            let (conn, addr) = match listener.accept().await {
                Ok(incoming) => incoming,
                Err(e) => {
                    eprintln!("Error accepting connection: {}", e);
                    continue;
                }
            };

            let http = http.clone();
            let tls_acceptor = tls_acceptor.clone();
            let svc = svc.clone();

            tokio::spawn(async move {
                // A peer that fails the handshake must not panic this task;
                // log the failure and drop the connection instead.
                let conn = match tls_acceptor.accept(conn).await {
                    Ok(conn) => conn,
                    Err(e) => {
                        eprintln!("TLS handshake with {} failed: {}", addr, e);
                        return;
                    }
                };

                // Peer certificates only become available once the handshake has
                // finished, so they are read from the established session.
                let certificates = conn
                    .get_ref()
                    .1
                    .peer_certificates()
                    .map(|certs| certs.to_vec())
                    .unwrap_or_default();

                // NOTE(review): this type is declared inside the task, so code
                // outside of it cannot name `ConnInfo` to fetch it from the request
                // extensions. Consider moving it to module scope if handlers are
                // meant to read the peer address or certificates.
                #[derive(Debug)]
                pub struct ConnInfo {
                    pub addr: std::net::SocketAddr,
                    pub certificates: Vec<rustls::pki_types::CertificateDer<'static>>,
                }

                let extension_layer =
                    tower_http::add_extension::AddExtensionLayer::new(Arc::new(ConnInfo {
                        addr,
                        certificates,
                    }));
                let svc = ServiceBuilder::new().layer(extension_layer).service(svc);

                let served = http
                    .serve_connection(
                        TokioIo::new(conn),
                        TowerToHyperService::new(
                            svc.map_request(|req: hyper::Request<_>| req.map(boxed)),
                        ),
                    )
                    .await;
                // Connection-level errors (resets, protocol violations, ...) are
                // per-peer events; report them without panicking.
                if let Err(e) = served {
                    eprintln!("Error serving connection from {}: {}", addr, e);
                }
            });
        }
    }
}
#[tonic::async_trait]
impl Update for MyServer {
    type GetUpdatesStream = Pin<Box<dyn Stream<Item = Result<NodeUpdate, Status>> + Send>>;

    /// Bidirectional stream of node updates.
    ///
    /// The response stream first replays the datastore's current node list and
    /// then forwards every update published on the broadcast channel. Updates
    /// received from the client are handed to the datastore and, when the
    /// datastore accepts them, re-published on the broadcast channel.
    ///
    /// The response stream ends after an error on the client stream, or with an
    /// error status when an update is delivered while more than nine receivers
    /// are subscribed to the broadcast channel.
    async fn get_updates(
        &self,
        req: Request<Streaming<NodeUpdate>>,
    ) -> Result<Response<Self::GetUpdatesStream>, Status> {
        self.ds.increase_mratls_conns();
        // The peer address is only present when the transport recorded it in the
        // request extensions; keep it optional instead of panicking when absent.
        let remote_ip = req.remote_addr().map(|addr| addr.ip().to_string());
        let tx = self.tx.clone();
        let mut rx = self.tx.subscribe();
        let mut inbound = req.into_inner();
        let ds = self.ds.clone();

        let stream = async_stream::stream! {
            // Collected up front so nothing borrowed from the datastore is held
            // across the yields below.
            let full_update_list: Vec<NodeUpdate> = ds.get_node_list().into_iter().map(Into::<NodeUpdate>::into).collect();
            for update in full_update_list {
                yield Ok(update);
            }

            // Cleared once the client finishes its side of the stream, so the
            // exhausted stream is not polled again while broadcasts keep flowing.
            let mut inbound_open = true;
            loop {
                tokio::select! {
                    msg = inbound.next(), if inbound_open => {
                        match msg {
                            Some(Ok(mut update)) => {
                                if update.ip == "localhost" {
                                    match &remote_ip {
                                        // note that we don't set this node online,
                                        // as it can be behind NAT
                                        Some(ip) => update.ip = ip.clone(),
                                        None => {
                                            eprintln!("Dropping localhost update: remote address of the client is unknown");
                                            continue;
                                        }
                                    }
                                }
                                // Updates for 127.0.0.1 are never processed; an update
                                // is broadcast only when the datastore reports `true`.
                                if update.ip != "127.0.0.1"
                                    && ds.process_node_update(update.clone().into())
                                    && tx.send(update).is_err()
                                {
                                    println!("tokio broadcast receivers had an issue consuming the channel");
                                }
                            }
                            Some(Err(e)) => {
                                yield Err(Status::internal(format!("Error receiving client stream: {}", e)));
                                break;
                            }
                            None => {
                                inbound_open = false;
                            }
                        }
                    }
                    received = rx.recv() => {
                        match received {
                            Ok(update) => {
                                yield Ok(update);
                                // disconnect client if too many connections are active
                                if tx.receiver_count() > 9 {
                                    yield Err(Status::internal("Already have too many clients. Connect to another server."));
                                    return;
                                }
                            }
                            // This receiver fell behind and missed some updates;
                            // keep streaming from the oldest retained one.
                            Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => {}
                            // All senders are gone, nothing more can arrive.
                            Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
                        }
                    }
                }
            }
        };

        Ok(Response::new(Box::pin(stream) as Self::GetUpdatesStream))
    }

    /// Returns the keypair bytes and token address currently held by the datastore.
    async fn get_keys(&self, _request: Request<Empty>) -> Result<Response<Keys>, Status> {
        let reply = Keys {
            keypair: self.ds.get_keypair_bytes(),
            token_address: self.ds.get_token_address(),
        };
        Ok(Response::new(reply))
    }
}