detee-sgx/examples/mratls_grpcs_server.rs

127 lines
4.0 KiB
Rust

/// Types generated at build time from the `grpc.examples.unaryecho` protobuf package
/// (`EchoRequest`, `EchoResponse`, and the `echo_server` service module).
pub mod pb {
    // `include_proto!` already inserts the "/" between `OUT_DIR` and the package name,
    // so the argument is the bare package name without a leading slash.
    tonic::include_proto!("grpc.examples.unaryecho");
}
use hyper::server::conn::http2::Builder;
use hyper_util::{
rt::{TokioExecutor, TokioIo},
service::TowerToHyperService,
};
use pb::{EchoRequest, EchoResponse};
use std::io::Write;
use std::sync::Arc;
use tokio::net::TcpListener;
use tokio_rustls::{
rustls::{pki_types::CertificateDer, ServerConfig},
TlsAcceptor,
};
use tonic::{body::boxed, service::Routes, Request, Response, Status};
use tower::ServiceBuilder;
use tower::ServiceExt;
use occlum_ratls::prelude::*;
use occlum_ratls::RaTlsConfigBuilder;
use std::sync::atomic::{AtomicUsize, Ordering};
/// Process-wide request counter; every handled request takes the next value
/// (starting at zero) to name the file it writes.
static COUNTER: AtomicUsize = AtomicUsize::new(0usize);
/// Runs the `Echo` gRPC service over RA-TLS on `[::1]:50051`.
///
/// Each accepted TCP connection is handled on its own Tokio task. The task:
/// - completes the TLS handshake;
/// - attaches the peer's address and certificate chain to every request as an
///   `Arc<ConnInfo>` extension;
/// - serves the connection with hyper's HTTP/2 server.
///
/// Failures on one connection are logged and only drop that connection.
///
/// # Errors
/// Returns an error if the instance measurement or the TLS server configuration cannot
/// be built, or if the listener cannot be bound.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    env_logger::init_from_env(env_logger::Env::default().default_filter_or("trace"));

    // Allow peers matching `with_current_mrsigner()` (presumably this enclave's own
    // MRSIGNER). To pin a specific signer instead, decode its hex MRSIGNER into a
    // `[u8; 32]` (e.g. with `hex::decode_to_slice`) and pass
    // `InstanceMeasurement::new().with_mrsigners(vec![mrsigner])`.
    let config = RaTlsConfig::new()
        .allow_instance_measurement(InstanceMeasurement::new().with_current_mrsigner()?);
    let mut tls = ServerConfig::from_ratls_config(config)
        .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?;
    // gRPC runs over HTTP/2, which is negotiated through ALPN.
    tls.alpn_protocols = vec![b"h2".to_vec()];

    let server = EchoServer::default();
    let svc = Routes::new(pb::echo_server::EchoServer::new(server)).prepare();
    let http = Builder::new(TokioExecutor::new());
    let listener = TcpListener::bind("[::1]:50051").await?;
    let tls_acceptor = TlsAcceptor::from(Arc::new(tls));

    loop {
        let (conn, addr) = match listener.accept().await {
            Ok(incoming) => incoming,
            Err(e) => {
                eprintln!("Error accepting connection: {}", e);
                continue;
            }
        };
        let http = http.clone();
        let tls_acceptor = tls_acceptor.clone();
        let svc = svc.clone();
        tokio::spawn(async move {
            // A failed handshake (e.g. a peer rejected by the TLS verifier) should only
            // drop this connection, not panic the task.
            let tls_stream = match tls_acceptor.accept(conn).await {
                Ok(stream) => stream,
                Err(e) => {
                    eprintln!("TLS handshake with {} failed: {}", addr, e);
                    return;
                }
            };
            // Peer certificates only exist once the handshake has completed, so read them
            // from the established session. An `accept_with` callback runs before the
            // handshake and would always observe `None`.
            let certificates = tls_stream
                .get_ref()
                .1
                .peer_certificates()
                .map(|certs| certs.to_vec())
                .unwrap_or_default();
            let extension_layer =
                tower_http::add_extension::AddExtensionLayer::new(Arc::new(ConnInfo {
                    addr,
                    certificates,
                }));
            let svc = ServiceBuilder::new().layer(extension_layer).service(svc);
            if let Err(e) = http
                .serve_connection(
                    TokioIo::new(tls_stream),
                    TowerToHyperService::new(
                        svc.map_request(|req: hyper::Request<_>| req.map(boxed)),
                    ),
                )
                .await
            {
                eprintln!("Error serving connection from {}: {}", addr, e);
            }
        });
    }
}
/// Metadata about one client connection, shared with every request on that
/// connection through an `Arc<ConnInfo>` request extension.
#[derive(Debug)]
struct ConnInfo {
    /// Remote socket address the client connected from.
    addr: std::net::SocketAddr,
    /// Peer certificate chain recorded for this connection (empty when none was collected).
    certificates: Vec<CertificateDer<'static>>,
}
/// Return type shared by the gRPC handlers: a tonic `Response<T>` on success,
/// or a gRPC `Status` describing the failure.
type EchoResult<T> = std::result::Result<Response<T>, Status>;
/// Field-less handler type implementing the `Echo` gRPC service.
#[derive(Default)]
pub struct EchoServer {}
#[tonic::async_trait]
impl pb::echo_server::Echo for EchoServer {
async fn unary_echo(&self, request: Request<EchoRequest>) -> EchoResult<EchoResponse> {
let conn_info = request.extensions().get::<Arc<ConnInfo>>().unwrap();
println!(
"Got a request from: {:?} with certs: {:?}",
conn_info.addr, conn_info.certificates
);
let count = COUNTER.fetch_add(1, Ordering::SeqCst);
let mut file = std::fs::File::create(format!("test{count}.txt"))?;
let content = format!("Hello, world! {count}");
file.write_all(content.as_bytes())?;
let message = request.into_inner().message;
Ok(Response::new(EchoResponse { message }))
}
}