hacker-challenge-sgx-general/src/grpc/server.rs
Valentyn Faychuk 01e90f874c
connections stability fixes
Signed-off-by: Valentyn Faychuk <valy@detee.ltd>
2024-12-24 17:46:39 +00:00

216 lines
8.0 KiB
Rust

use super::{
challenge::{update_server::UpdateServer, Empty, Keys, NodeUpdate},
InternalNodeUpdate,
};
use crate::{datastore::State, grpc::challenge::update_server::Update};
use detee_sgx::RaTlsConfig;
use rustls::pki_types::CertificateDer;
use std::{pin::Pin, sync::Arc};
use tokio::sync::broadcast::Sender;
use tokio::time::interval;
use tokio_stream::{Stream, StreamExt};
use tonic::{Request, Response, Status, Streaming};
/// Builds a [`NodeServer`] from the shared state, the keys handed to new nodes, the RA-TLS
/// configuration and the update broadcast sender, then runs its accept loop.
///
/// The accept loop in `NodeServer::start` has no exit path, so under normal operation this
/// future never completes.
pub async fn grpc_new_server(
    state: Arc<State>,
    keys: Keys,
    ra_cfg: RaTlsConfig,
    tx: Sender<InternalNodeUpdate>,
) {
    let server = NodeServer::init(state, keys, ra_cfg, tx);
    server.start().await;
}
/// Implementation of the `Update` gRPC service, served over RA-TLS by `NodeServer::start`.
pub struct NodeServer {
    /// Shared node state (see `crate::datastore::State`).
    state: Arc<State>,
    /// Broadcast sender; every `get_updates` stream subscribes to it and re-broadcasts
    /// updates that `State` says must be forwarded.
    tx: Sender<InternalNodeUpdate>,
    /// RA-TLS configuration from which the rustls server config is built.
    ra_cfg: RaTlsConfig,
    /// Secret keys handed out to new nodes through `get_keys` ;)
    keys: Keys,
}
impl NodeServer {
    /// Creates a server from its shared state, the keys to hand out to new nodes, the RA-TLS
    /// configuration, and the broadcast sender used to fan out node updates.
    pub fn init(
        state: Arc<State>,
        keys: Keys,
        ra_cfg: RaTlsConfig,
        tx: Sender<InternalNodeUpdate>,
    ) -> Self {
        Self { state, tx, keys, ra_cfg }
    }

    /// Listens on `0.0.0.0:31373` and serves the `Update` gRPC service over RA-TLS (HTTP/2
    /// via ALPN `h2`).
    ///
    /// Each accepted TCP connection is handled on its own task:
    /// - the TLS handshake is performed and the peer's certificates are captured;
    /// - an `Arc<ConnInfo>` extension is attached to every request;
    /// - the peer IP is registered with `State::add_conn` while the connection is served and
    ///   removed with `State::delete_conn` afterwards.
    ///
    /// The accept loop never exits.
    ///
    /// # Panics
    ///
    /// Panics if the rustls server config cannot be built from the RA-TLS config, or if the
    /// listener cannot be bound.
    pub async fn start(self) {
        use detee_sgx::RaTlsConfigBuilder;
        use hyper::server::conn::http2::Builder;
        use hyper_util::{
            rt::{TokioExecutor, TokioIo},
            service::TowerToHyperService,
        };
        use tokio::net::TcpListener;
        use tokio_rustls::{rustls::ServerConfig, TlsAcceptor};
        use tonic::{body::boxed, service::Routes};
        use tower::{ServiceBuilder, ServiceExt};

        // TODO: error handling, shouldn't have expects
        let mut tls = ServerConfig::from_ratls_config(self.ra_cfg.clone())
            .expect("failed to build TLS server config from RA-TLS config");
        tls.alpn_protocols = vec![b"h2".to_vec()];
        let state = self.state.clone();
        let svc = Routes::new(UpdateServer::new(self)).prepare();
        let http = Builder::new(TokioExecutor::new());
        let listener = TcpListener::bind("0.0.0.0:31373").await.expect("failed to bind listener");
        let tls_acceptor = TlsAcceptor::from(Arc::new(tls));
        loop {
            let (conn, addr) = match listener.accept().await {
                Ok(incoming) => incoming,
                Err(e) => {
                    println!("Error accepting connection: {}", e);
                    // Some accept errors (e.g. running out of file descriptors) persist until
                    // resources are released; back off briefly instead of spinning this loop.
                    tokio::time::sleep(std::time::Duration::from_millis(100)).await;
                    continue;
                }
            };
            let http = http.clone();
            let tls_acceptor = tls_acceptor.clone();
            let svc = svc.clone();
            // Called on every accepted connection; presumably records that this node is
            // reachable from outside — confirm against `State::declare_myself_public`.
            state.declare_myself_public().await;
            let state = state.clone();
            tokio::spawn(async move {
                // Certificates presented by the peer, exposed to handlers through `ConnInfo`.
                let mut certificates = Vec::new();
                // NOTE(review): the handshake is not bounded by a timeout, so a peer that
                // stalls mid-handshake keeps this task and its socket alive indefinitely.
                let conn = match tls_acceptor
                    .accept_with(conn, |info| {
                        if let Some(certs) = info.peer_certificates() {
                            certificates.extend(certs.iter().cloned());
                        }
                    })
                    .await
                {
                    Ok(conn) => conn,
                    Err(e) => {
                        println!("Error accepting TLS connection: {e}");
                        // Handshake failures are counted as network attacks.
                        if e.to_string().contains("HandshakeFailure") {
                            state.increase_net_attacks().await;
                        }
                        return;
                    }
                };
                let extension_layer =
                    tower_http::add_extension::AddExtensionLayer::new(Arc::new(ConnInfo {
                        addr,
                        certificates,
                    }));
                let svc = ServiceBuilder::new().layer(extension_layer).service(svc);
                let ip = addr.ip().to_string();
                // The IP stays registered with `State` for as long as the connection is served.
                state.add_conn(&ip).await;
                if let Err(e) = http
                    .serve_connection(
                        TokioIo::new(conn),
                        TowerToHyperService::new(
                            svc.map_request(|req: hyper::Request<_>| req.map(boxed)),
                        ),
                    )
                    .await
                {
                    println!("Error serving connection: {}", e);
                }
                state.delete_conn(&ip).await;
            });
        }
    }
}
/// Per-connection metadata, attached to every request as an `Arc<ConnInfo>` extension by the
/// accept loop in `NodeServer::start`.
#[derive(Debug)]
struct ConnInfo {
    /// Remote socket address of the peer.
    addr: std::net::SocketAddr,
    /// Certificates the peer presented during the TLS handshake.
    // Not read by any handler in this file yet; the allow silences the dead-code lint.
    #[allow(dead_code)]
    certificates: Vec<CertificateDer<'static>>,
}
#[tonic::async_trait]
impl Update for NodeServer {
async fn get_keys(&self, _request: Request<Empty>) -> Result<Response<Keys>, Status> {
Ok(Response::new(self.keys.clone()))
}
type GetUpdatesStream = Pin<Box<dyn Stream<Item = Result<NodeUpdate, Status>> + Send>>;
async fn get_updates(
&self,
req: Request<Streaming<NodeUpdate>>,
) -> Result<Response<Self::GetUpdatesStream>, Status> {
// connection info is added in the tower for tls connection and must be present
let conn_info = req.extensions().get::<Arc<ConnInfo>>().unwrap();
let remote_ip = conn_info.addr.ip().to_string();
let tx = self.tx.clone();
let mut rx = self.tx.subscribe();
let mut inbound = req.into_inner();
let state = self.state.clone();
let my_ip = self.state.get_my_ip().await;
let stream = async_stream::stream! {
let known_nodes: Vec<NodeUpdate> = state.get_nodes().await.into_iter().map(Into::into).collect();
for update in known_nodes {
yield Ok(update);
}
let error_status: Status; // Gets initialized inside loop
let mut timeout = interval(state.get_timeout());
loop {
tokio::select! {
Some(msg) = inbound.next() => {
timeout = interval(state.get_timeout());
match msg {
Ok(update) => {
if update.ip == remote_ip {
println!("Node {remote_ip} is sending it's own update");
} else if update.ip == my_ip {
println!("Node {remote_ip} is forwarding our past update");
} else {
println!("Node {remote_ip} is forwarding the update of {}", update.ip);
}
if state.process_node_update(update.clone().into()).await {
// If process update returns true, the update must be forwarded
if tx.send((remote_ip.clone(), update).into()).is_err() {
println!("Tokio broadcast receivers had an issue consuming the channel");
};
}
}
Err(e) => {
error_status = Status::internal(format!("Error receiving client stream: {}", e));
break;
}
}
}
Ok(update) = rx.recv() => {
if update.sender_ip != remote_ip {
// Don't bounce back the update we just received
yield Ok(update.update);
}
// TODO: check if disconnect client if too many connections are active
if tx.receiver_count() > 9 {
error_status = Status::internal("Already have too many clients. Connect to another server.");
break;
}
}
_ = timeout.tick() => {
error_status = Status::internal(format!("Disconnecting after {}s timeout", state.get_timeout().as_secs()));
break;
}
}
}
yield Err(error_status);
};
Ok(Response::new(Box::pin(stream) as Self::GetUpdatesStream))
}
}