180 lines
6.4 KiB
Rust
180 lines
6.4 KiB
Rust
use super::challenge::{update_server::UpdateServer, Empty, Keys, NodeUpdate};
|
|
use crate::{datastore::State, grpc::challenge::update_server::Update};
|
|
use detee_sgx::RaTlsConfig;
|
|
use rustls::pki_types::CertificateDer;
|
|
use std::{pin::Pin, sync::Arc};
|
|
use tokio::sync::broadcast::Sender;
|
|
use tokio_stream::{Stream, StreamExt};
|
|
use tonic::{Request, Response, Status, Streaming};
|
|
|
|
pub struct MyServer {
|
|
state: Arc<State>,
|
|
tx: Sender<NodeUpdate>,
|
|
ratls_config: RaTlsConfig,
|
|
keys: Keys, // For sending secret keys to new nodes ;)
|
|
}
|
|
|
|
impl MyServer {
|
|
pub fn init(
|
|
state: Arc<State>,
|
|
keys: Keys,
|
|
ratls_config: RaTlsConfig,
|
|
tx: Sender<NodeUpdate>,
|
|
) -> Self {
|
|
Self { state, tx, keys, ratls_config }
|
|
}
|
|
|
|
pub async fn start(self) {
|
|
use detee_sgx::RaTlsConfigBuilder;
|
|
use hyper::server::conn::http2::Builder;
|
|
use hyper_util::{
|
|
rt::{TokioExecutor, TokioIo},
|
|
service::TowerToHyperService,
|
|
};
|
|
use std::sync::Arc;
|
|
use tokio::net::TcpListener;
|
|
use tokio_rustls::{rustls::ServerConfig, TlsAcceptor};
|
|
use tonic::{body::boxed, service::Routes};
|
|
use tower::{ServiceBuilder, ServiceExt};
|
|
|
|
// TODO: error handling, shouldn't have expects
|
|
|
|
let mut tls = ServerConfig::from_ratls_config(self.ratls_config.clone()).unwrap();
|
|
tls.alpn_protocols = vec![b"h2".to_vec()];
|
|
|
|
let state = self.state.clone();
|
|
let svc = Routes::new(UpdateServer::new(self)).prepare();
|
|
|
|
let http = Builder::new(TokioExecutor::new());
|
|
|
|
let listener = TcpListener::bind("0.0.0.0:31373").await.expect("failed to bind listener");
|
|
let tls_acceptor = TlsAcceptor::from(Arc::new(tls));
|
|
|
|
loop {
|
|
let (conn, addr) = match listener.accept().await {
|
|
Ok(incoming) => incoming,
|
|
Err(e) => {
|
|
println!("Error accepting connection: {}", e);
|
|
state.increase_net_attacks();
|
|
continue;
|
|
}
|
|
};
|
|
|
|
let http = http.clone();
|
|
let tls_acceptor = tls_acceptor.clone();
|
|
let svc = svc.clone();
|
|
|
|
let state = state.clone();
|
|
tokio::spawn(async move {
|
|
let mut certificates = Vec::new();
|
|
|
|
let conn = tls_acceptor
|
|
.accept_with(conn, |info| {
|
|
if let Some(certs) = info.peer_certificates() {
|
|
for cert in certs {
|
|
certificates.push(cert.clone());
|
|
}
|
|
}
|
|
})
|
|
.await;
|
|
|
|
let conn = if let Err(e) = conn {
|
|
println!("Error accepting TLS connection: {}", e);
|
|
state.increase_net_attacks();
|
|
return;
|
|
} else {
|
|
conn.unwrap()
|
|
};
|
|
|
|
let extension_layer =
|
|
tower_http::add_extension::AddExtensionLayer::new(Arc::new(ConnInfo {
|
|
addr,
|
|
certificates,
|
|
}));
|
|
let svc = ServiceBuilder::new().layer(extension_layer).service(svc);
|
|
|
|
if let Err(e) = http
|
|
.serve_connection(
|
|
TokioIo::new(conn),
|
|
TowerToHyperService::new(
|
|
svc.map_request(|req: hyper::Request<_>| req.map(boxed)),
|
|
),
|
|
)
|
|
.await
|
|
{
|
|
println!("Error serving connection: {}", e);
|
|
}
|
|
});
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
struct ConnInfo {
|
|
addr: std::net::SocketAddr,
|
|
certificates: Vec<CertificateDer<'static>>,
|
|
}
|
|
|
|
#[tonic::async_trait]
|
|
impl Update for MyServer {
|
|
type GetUpdatesStream = Pin<Box<dyn Stream<Item = Result<NodeUpdate, Status>> + Send>>;
|
|
|
|
async fn get_keys(&self, _request: Request<Empty>) -> Result<Response<Keys>, Status> {
|
|
Ok(Response::new(self.keys.clone()))
|
|
}
|
|
|
|
async fn get_updates(
|
|
&self,
|
|
req: Request<Streaming<NodeUpdate>>,
|
|
) -> Result<Response<Self::GetUpdatesStream>, Status> {
|
|
let conn_info = req.extensions().get::<Arc<ConnInfo>>().unwrap();
|
|
self.state.increase_mratls_conns();
|
|
let remote_ip = conn_info.addr.ip().to_string();
|
|
let tx = self.tx.clone();
|
|
let mut rx = self.tx.subscribe();
|
|
let mut inbound = req.into_inner();
|
|
let ds = self.state.clone();
|
|
|
|
let stream = async_stream::stream! {
|
|
let full_update_list: Vec<NodeUpdate> = ds.get_node_list().into_iter().map(Into::<NodeUpdate>::into).collect();
|
|
for update in full_update_list {
|
|
yield Ok(update);
|
|
}
|
|
|
|
loop {
|
|
tokio::select! {
|
|
Some(msg) = inbound.next() => {
|
|
match msg {
|
|
Ok(mut update) => {
|
|
if update.ip == "localhost" {
|
|
update.ip = remote_ip.clone();
|
|
// note that we don't set this node online,
|
|
// as it can be behind NAT
|
|
}
|
|
if update.ip != "127.0.0.1" && ds.process_node_update(update.clone().into()) && tx.send(update.clone()).is_err() {
|
|
println!("tokio broadcast receivers had an issue consuming the channel");
|
|
};
|
|
}
|
|
Err(e) => {
|
|
yield Err(Status::internal(format!("Error receiving client stream: {}", e)));
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
Ok(update) = rx.recv() => {
|
|
yield Ok(update);
|
|
// disconnect client if too many connections are active
|
|
if tx.receiver_count() > 9 {
|
|
yield Err(Status::internal("Already have too many clients. Connect to another server."));
|
|
return;
|
|
}
|
|
}
|
|
|
|
}
|
|
}
|
|
};
|
|
|
|
Ok(Response::new(Box::pin(stream) as Self::GetUpdatesStream))
|
|
}
|
|
}
|