pub mod pb { tonic::include_proto!("/grpc.examples.unaryecho"); } use hyper::server::conn::http2::Builder; use hyper_util::{ rt::{TokioExecutor, TokioIo}, service::TowerToHyperService, }; use pb::{EchoRequest, EchoResponse}; use std::io::Write; use std::sync::Arc; use tokio::net::TcpListener; use tokio_rustls::{ rustls::{pki_types::CertificateDer, ServerConfig}, TlsAcceptor, }; use tonic::{body::boxed, service::Routes, Request, Response, Status}; use tower::ServiceBuilder; use tower::ServiceExt; use occlum_ratls::prelude::*; use occlum_ratls::RaTlsConfigBuilder; use std::sync::atomic::{AtomicUsize, Ordering}; static COUNTER: AtomicUsize = AtomicUsize::new(0); #[tokio::main] async fn main() -> Result<(), Box> { env_logger::init_from_env(env_logger::Env::default().default_filter_or("trace")); // let mrsigner_hex = "83E8A0C3ED045D9747ADE06C3BFC70FCA661A4A65FF79A800223621162A88B76"; // let mut mrsigner = [0u8; 32]; // hex::decode_to_slice(mrsigner_hex, &mut mrsigner).expect("mrsigner decoding failed"); let config = RaTlsConfig::new().allow_instance_measurement( // InstanceMeasurement::new().with_mrsigners(vec![mrsigner]) InstanceMeasurement::new().load_mr_signer_from_processor()?, ); let mut tls = ServerConfig::from_ratls_config(config) .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, format!("{}", e)))?; tls.alpn_protocols = vec![b"h2".to_vec()]; let server = EchoServer::default(); let svc = Routes::new(pb::echo_server::EchoServer::new(server)).prepare(); let http = Builder::new(TokioExecutor::new()); let listener = TcpListener::bind("[::1]:50051").await?; let tls_acceptor = TlsAcceptor::from(Arc::new(tls)); loop { let (conn, addr) = match listener.accept().await { Ok(incoming) => incoming, Err(e) => { eprintln!("Error accepting connection: {}", e); continue; } }; let http = http.clone(); let tls_acceptor = tls_acceptor.clone(); let svc = svc.clone(); tokio::spawn(async move { let mut certificates = Vec::new(); let conn = tls_acceptor .accept_with(conn, |info| { if let Some(certs) = info.peer_certificates() { for cert in certs { certificates.push(cert.clone()); } } }) .await .unwrap(); let extension_layer = tower_http::add_extension::AddExtensionLayer::new(Arc::new(ConnInfo { addr, certificates, })); let svc = ServiceBuilder::new().layer(extension_layer).service(svc); http.serve_connection( TokioIo::new(conn), TowerToHyperService::new(svc.map_request(|req: hyper::Request<_>| req.map(boxed))), ) .await .unwrap(); }); } } #[derive(Debug)] struct ConnInfo { addr: std::net::SocketAddr, certificates: Vec>, } type EchoResult = Result, Status>; #[derive(Default)] pub struct EchoServer {} #[tonic::async_trait] impl pb::echo_server::Echo for EchoServer { async fn unary_echo(&self, request: Request) -> EchoResult { let conn_info = request.extensions().get::>().unwrap(); println!( "Got a request from: {:?} with certs: {:?}", conn_info.addr, conn_info.certificates ); let count = COUNTER.fetch_add(1, Ordering::SeqCst); let mut file = std::fs::File::create(format!("test{count}.txt"))?; let content = format!("Hello, world! {count}"); file.write_all(content.as_bytes())?; let message = request.into_inner().message; Ok(Response::new(EchoResponse { message })) } }