From 8abe8eebf45d4beea15fe90b772459b98bafaebb Mon Sep 17 00:00:00 2001 From: Noor Date: Fri, 2 May 2025 19:33:34 +0530 Subject: [PATCH] Refactor database connection handling use SurrealDB directly in Brain services modified tests accordingly cleanedup tests comments --- src/bin/brain.rs | 11 ++-- src/grpc.rs | 107 ++++++++++++------------------- src/lib.rs | 6 -- tests/common/prepare_test_env.rs | 74 +++++---------------- 4 files changed, 60 insertions(+), 138 deletions(-) diff --git a/src/bin/brain.rs b/src/bin/brain.rs index 5d6abdd..d11360d 100644 --- a/src/bin/brain.rs +++ b/src/bin/brain.rs @@ -9,22 +9,21 @@ use surreal_brain::constants::{ use surreal_brain::db; use surreal_brain::grpc::BrainVmCliForReal; use surreal_brain::grpc::{BrainGeneralCliForReal, BrainVmDaemonForReal}; -use surreal_brain::BrainState; use tonic::transport::{Identity, Server, ServerTlsConfig}; #[tokio::main] async fn main() { env_logger::builder().filter_level(log::LevelFilter::Debug).init(); - let db_connection = db::db_connection(DB_ADDRESS, DB_NS, DB_NAME).await.unwrap(); - let state = Arc::new(BrainState { db_connection }); + let db = db::db_connection(DB_ADDRESS, DB_NS, DB_NAME).await.unwrap(); + let db_arc = Arc::new(db); let addr = BRAIN_GRPC_ADDR.parse().unwrap(); - let snp_daemon_server = BrainVmDaemonServer::new(BrainVmDaemonForReal::new(state.clone())); - let snp_cli_server = BrainVmCliServer::new(BrainVmCliForReal::new(state.clone())); + let snp_daemon_server = BrainVmDaemonServer::new(BrainVmDaemonForReal::new(db_arc.clone())); + let snp_cli_server = BrainVmCliServer::new(BrainVmCliForReal::new(db_arc.clone())); let general_service_server = - BrainGeneralCliServer::new(BrainGeneralCliForReal::new(state.clone())); + BrainGeneralCliServer::new(BrainGeneralCliForReal::new(db_arc.clone())); let cert = std::fs::read_to_string(CERT_PATH).unwrap(); let key = std::fs::read_to_string(CERT_KEY_PATH).unwrap(); diff --git a/src/grpc.rs b/src/grpc.rs index 415f778..7209182 100644 --- a/src/grpc.rs +++ b/src/grpc.rs @@ -1,6 +1,6 @@ #![allow(dead_code)] use crate::constants::{ACCOUNT, ADMIN_ACCOUNTS, VM_NODE}; -use crate::{db, BrainState}; +use crate::db; use detee_shared::app_proto::{AppContract, AppNodeListResp}; use detee_shared::{ common_proto::{Empty, Pubkey}, @@ -15,6 +15,7 @@ use detee_shared::{ }, }; use nanoid::nanoid; +use surrealdb::{engine::remote::ws::Client, Surreal}; use log::info; use std::pin::Pin; @@ -26,12 +27,12 @@ use tokio_stream::{Stream, StreamExt}; use tonic::{Request, Response, Status, Streaming}; pub struct BrainGeneralCliForReal { - state: Arc, + pub db: Arc>, } impl BrainGeneralCliForReal { - pub fn new(state: Arc) -> Self { - Self { state } + pub fn new(db: Arc>) -> Self { + Self { db } } } @@ -249,12 +250,12 @@ impl From for db::VmNodeResources { } pub struct BrainVmDaemonForReal { - pub state: Arc, + pub db: Arc>, } impl BrainVmDaemonForReal { - pub fn new(state: Arc) -> Self { - Self { state } + pub fn new(db: Arc>) -> Self { + Self { db } } } @@ -284,12 +285,11 @@ impl BrainVmDaemon for BrainVmDaemonForReal { max_ports_per_vm: 0, offline_minutes: 0, } - .register(&self.state.db_connection) + .register(&self.db) .await?; info!("Sending existing contracts to {}", req.node_pubkey); - let contracts = - db::ActiveVmWithNode::list_by_node(&self.state.db_connection, &req.node_pubkey).await?; + let contracts = db::ActiveVmWithNode::list_by_node(&self.db, &req.node_pubkey).await?; let (tx, rx) = mpsc::channel(6); tokio::spawn(async move { for contract in contracts { @@ -317,12 +317,11 @@ impl BrainVmDaemon for BrainVmDaemonForReal { let (tx, rx) = mpsc::channel(6); { - let state = self.state.clone(); + let db = self.db.clone(); let pubkey = pubkey.clone(); let tx = tx.clone(); tokio::spawn(async move { - match db::listen_for_node::(&state.db_connection, &pubkey, tx).await - { + match db::listen_for_node::(&db, &pubkey, tx).await { Ok(()) => log::info!("db::VmContract::listen_for_node ended for {pubkey}"), Err(e) => { log::warn!("db::VmContract::listen_for_node errored for {pubkey}: {e}") @@ -331,26 +330,19 @@ impl BrainVmDaemon for BrainVmDaemonForReal { }); } { - let state = self.state.clone(); + let db = self.db.clone(); let pubkey = pubkey.clone(); let tx = tx.clone(); tokio::spawn(async move { - let _ = - db::listen_for_node::(&state.db_connection, &pubkey, tx.clone()) - .await; + let _ = db::listen_for_node::(&db, &pubkey, tx.clone()).await; }); } { - let state = self.state.clone(); + let db = self.db.clone(); let pubkey = pubkey.clone(); let tx = tx.clone(); tokio::spawn(async move { - let _ = db::listen_for_node::( - &state.db_connection, - &pubkey, - tx.clone(), - ) - .await; + let _ = db::listen_for_node::(&db, &pubkey, tx.clone()).await; }); } @@ -391,26 +383,21 @@ impl BrainVmDaemon for BrainVmDaemonForReal { // also handle failure properly if !new_vm_resp.error.is_empty() { db::NewVmReq::submit_error( - &self.state.db_connection, + &self.db, &new_vm_resp.uuid, new_vm_resp.error, ) .await?; } else { db::upsert_record( - &self.state.db_connection, + &self.db, "measurement_args", &new_vm_resp.uuid, new_vm_resp.args.clone(), ) .await?; if let Some(args) = new_vm_resp.args { - db::ActiveVm::activate( - &self.state.db_connection, - &new_vm_resp.uuid, - args, - ) - .await?; + db::ActiveVm::activate(&self.db, &new_vm_resp.uuid, args).await?; } } } @@ -420,7 +407,7 @@ impl BrainVmDaemon for BrainVmDaemonForReal { } Some(vm_daemon_message::Msg::VmNodeResources(node_resources)) => { let node_resources: db::VmNodeResources = node_resources.into(); - node_resources.merge(&self.state.db_connection, &pubkey).await?; + node_resources.merge(&self.db, &pubkey).await?; } _ => {} }, @@ -444,15 +431,13 @@ impl BrainGeneralCli for BrainGeneralCliForReal { async fn get_balance(&self, req: Request) -> Result, Status> { let req = check_sig_from_req(req)?; - Ok(Response::new(db::Account::get(&self.state.db_connection, &req.pubkey).await?.into())) + Ok(Response::new(db::Account::get(&self.db, &req.pubkey).await?.into())) } async fn report_node(&self, req: Request) -> Result, Status> { let req = check_sig_from_req(req)?; let (account, node) = - match db::ActiveVmWithNode::get_by_uuid(&self.state.db_connection, &req.contract) - .await? - { + match db::ActiveVmWithNode::get_by_uuid(&self.db, &req.contract).await? { Some(vm_contract) if vm_contract.admin.key().to_string() == req.admin_pubkey && vm_contract.vm_node.id.key().to_string() == req.node_pubkey => @@ -464,7 +449,7 @@ impl BrainGeneralCli for BrainGeneralCliForReal { return Err(Status::unauthenticated("No contract found by this ID.")); } }; - db::Report::create(&self.state.db_connection, account, node, req.reason).await?; + db::Report::create(&self.db, account, node, req.reason).await?; Ok(Response::new(Empty {})) } @@ -473,7 +458,7 @@ impl BrainGeneralCli for BrainGeneralCliForReal { req: Request, ) -> Result, Status> { let _ = check_sig_from_req(req)?; - let operators = db::Operator::list(&self.state.db_connection).await?; + let operators = db::Operator::list(&self.db).await?; let (tx, rx) = mpsc::channel(6); tokio::spawn(async move { for op in operators { @@ -488,9 +473,7 @@ impl BrainGeneralCli for BrainGeneralCliForReal { &self, req: Request, ) -> Result, Status> { - match db::Operator::inspect_nodes(&self.state.db_connection, &req.into_inner().pubkey) - .await? - { + match db::Operator::inspect_nodes(&self.db, &req.into_inner().pubkey).await? { (Some(op), vm_nodes, app_nodes) => Ok(Response::new(InspectOperatorResp { operator: Some(op.into()), vm_nodes: vm_nodes.into_iter().map(|n| n.into()).collect(), @@ -534,7 +517,7 @@ impl BrainGeneralCli for BrainGeneralCliForReal { async fn airdrop(&self, req: Request) -> Result, Status> { check_admin_key(&req)?; let req = check_sig_from_req(req)?; - db::Account::airdrop(&self.state.db_connection, &req.pubkey, req.tokens).await?; + db::Account::airdrop(&self.db, &req.pubkey, req.tokens).await?; Ok(Response::new(Empty {})) } @@ -602,12 +585,12 @@ impl BrainGeneralCli for BrainGeneralCliForReal { } pub struct BrainVmCliForReal { - state: Arc, + pub db: Arc>, } impl BrainVmCliForReal { - pub fn new(state: Arc) -> Self { - Self { state } + pub fn new(db: Arc>) -> Self { + Self { db } } } @@ -619,26 +602,20 @@ impl BrainVmCli for BrainVmCliForReal { async fn new_vm(&self, req: Request) -> Result, Status> { let req = check_sig_from_req(req)?; info!("New VM requested via CLI: {req:?}"); - if db::Account::is_banned_by_node( - &self.state.db_connection, - &req.admin_pubkey, - &req.node_pubkey, - ) - .await? - { + if db::Account::is_banned_by_node(&self.db, &req.admin_pubkey, &req.node_pubkey).await? { return Err(Status::permission_denied("This operator banned you. What did you do?")); } let new_vm_req: db::NewVmReq = req.into(); let id = new_vm_req.id.key().to_string(); - let state = self.state.clone(); + let db = self.db.clone(); let (oneshot_tx, oneshot_rx) = tokio::sync::oneshot::channel(); tokio::spawn(async move { - let _ = oneshot_tx.send(db::NewVmResp::listen(&state.db_connection, &id).await); + let _ = oneshot_tx.send(db::NewVmResp::listen(&db, &id).await); }); - new_vm_req.submit(&self.state.db_connection).await?; + new_vm_req.submit(&self.db).await?; match oneshot_rx.await { Ok(new_vm_resp) => Ok(Response::new(new_vm_resp?.into())), @@ -701,7 +678,7 @@ impl BrainVmCli for BrainVmCliForReal { let mut contracts = Vec::new(); if !req.uuid.is_empty() { if let Some(specific_contract) = - db::ActiveVmWithNode::get_by_uuid(&self.state.db_connection, &req.uuid).await? + db::ActiveVmWithNode::get_by_uuid(&self.db, &req.uuid).await? { if specific_contract.admin.key().to_string() == req.wallet { contracts.push(specific_contract); @@ -709,15 +686,11 @@ impl BrainVmCli for BrainVmCliForReal { // TODO: allow operator to inspect contracts } } else if req.as_operator { - contracts.append( - &mut db::ActiveVmWithNode::list_by_operator(&self.state.db_connection, &req.wallet) - .await?, - ); + contracts + .append(&mut db::ActiveVmWithNode::list_by_operator(&self.db, &req.wallet).await?); } else { - contracts.append( - &mut db::ActiveVmWithNode::list_by_admin(&self.state.db_connection, &req.wallet) - .await?, - ); + contracts + .append(&mut db::ActiveVmWithNode::list_by_admin(&self.db, &req.wallet).await?); } let (tx, rx) = mpsc::channel(6); tokio::spawn(async move { @@ -735,7 +708,7 @@ impl BrainVmCli for BrainVmCliForReal { ) -> Result, tonic::Status> { let req = check_sig_from_req(req)?; info!("CLI requested ListVmNodesStream: {req:?}"); - let nodes = db::VmNodeWithReports::find_by_filters(&self.state.db_connection, req).await?; + let nodes = db::VmNodeWithReports::find_by_filters(&self.db, req).await?; let (tx, rx) = mpsc::channel(6); tokio::spawn(async move { for node in nodes { @@ -753,7 +726,7 @@ impl BrainVmCli for BrainVmCliForReal { let req = check_sig_from_req(req)?; info!("Unknown CLI requested ListVmNodesStream: {req:?}"); // TODO: optimize this query so that it gets only one node - let nodes = db::VmNodeWithReports::find_by_filters(&self.state.db_connection, req).await?; + let nodes = db::VmNodeWithReports::find_by_filters(&self.db, req).await?; if let Some(node) = nodes.into_iter().next() { return Ok(Response::new(node.into())); } diff --git a/src/lib.rs b/src/lib.rs index 10377e2..9e4aaa0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,10 +1,4 @@ -use surrealdb::{engine::remote::ws::Client, Surreal}; - pub mod constants; pub mod db; pub mod grpc; pub mod old_brain; - -pub struct BrainState { - pub db_connection: Surreal, -} diff --git a/tests/common/prepare_test_env.rs b/tests/common/prepare_test_env.rs index 14b2ec4..d6310b1 100644 --- a/tests/common/prepare_test_env.rs +++ b/tests/common/prepare_test_env.rs @@ -7,7 +7,6 @@ use hyper_util::rt::TokioIo; use std::net::SocketAddr; use std::sync::Arc; use surreal_brain::grpc::{BrainGeneralCliForReal, BrainVmCliForReal, BrainVmDaemonForReal}; -use surreal_brain::BrainState; use tokio::io::DuplexStream; use tokio::{net::TcpListener, sync::OnceCell}; use tonic::transport::{Channel, Endpoint, Server, Uri}; @@ -23,73 +22,30 @@ pub const DB_NAME: &str = "test_migration_db"; pub static DB_STATE: OnceCell<()> = OnceCell::const_new(); pub async fn prepare_test_db() -> Surreal { - let db_connection = surreal_brain::db::db_connection(DB_URL, DB_NS, DB_NAME).await.unwrap(); + let db = surreal_brain::db::db_connection(DB_URL, DB_NS, DB_NAME).await.unwrap(); DB_STATE .get_or_init(|| async { - // surreal_brain::db::init(DB_URL, DB_NS, DB_NAME) - // .await - // .expect("Failed to initialize the database"); - - // surreal_brain::db::DB - // .set(surreal_brain::db::db_connection(DB_URL, DB_NS, DB_NAME).await.unwrap()) - // .unwrap(); - let old_brain_data = surreal_brain::old_brain::BrainData::load_from_disk().unwrap(); - db_connection.query(format!("REMOVE DATABASE {DB_NAME}")).await.unwrap(); - db_connection - .query(std::fs::read_to_string("interim_tables.surql").unwrap()) - .await - .unwrap(); - surreal_brain::db::migration0(&db_connection, &old_brain_data).await.unwrap(); + db.query(format!("REMOVE DATABASE {DB_NAME}")).await.unwrap(); + db.query(std::fs::read_to_string("interim_tables.surql").unwrap()).await.unwrap(); + surreal_brain::db::migration0(&db, &old_brain_data).await.unwrap(); }) .await; - db_connection + db } -/* -use db::DB; - -#[tokio::main] -async fn main() { - env_logger::builder().filter_level(log::LevelFilter::Debug).init(); - // db::init(DB_ADDRESS, DB_NS, DB_NAME).await.unwrap(); - - DB.set(db::db_connection(DB_ADDRESS, DB_NS, DB_NAME).await.unwrap()).unwrap(); - - let db_connection = db::db_connection(DB_ADDRESS, DB_NS, DB_NAME).await.unwrap(); - let state = Arc::new(BrainState { db_connection }); - - let addr = BRAIN_GRPC_ADDR.parse().unwrap(); - - let snp_daemon_server = BrainVmDaemonServer::new(BrainVmDaemonForReal::new(state.clone())); - let snp_cli_server = BrainVmCliServer::new(BrainVmCliForReal::new(state.clone())); - let general_service_server = - BrainGeneralCliServer::new(BrainGeneralCliForReal::new(state.clone())); - - let cert = std::fs::read_to_string(CERT_PATH).unwrap(); - let key = std::fs::read_to_string(CERT_KEY_PATH).unwrap(); - - let identity = Identity::from_pem(cert, key); - - Server::builder() - .tls_config(ServerTlsConfig::new().identity(identity)) - .unwrap() - .add_service(snp_daemon_server) - .add_service(snp_cli_server) - */ - pub async fn run_service_in_background() -> SocketAddr { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); tokio::spawn(async move { - let db_connection = surreal_brain::db::db_connection(DB_URL, DB_NS, DB_NAME).await.unwrap(); - let state = Arc::new(BrainState { db_connection }); + let db = surreal_brain::db::db_connection(DB_URL, DB_NS, DB_NAME).await.unwrap(); + let db_arc = Arc::new(db); Server::builder() - .add_service(BrainGeneralCliServer::new(BrainGeneralCliForReal::new(state.clone()))) - .add_service(BrainVmCliServer::new(BrainVmCliForReal::new(state.clone()))) - .add_service(BrainVmDaemonServer::new(BrainVmDaemonForReal::new(state.clone()))) + .add_service(BrainGeneralCliServer::new(BrainGeneralCliForReal::new(db_arc.clone()))) + .add_service(BrainVmCliServer::new(BrainVmCliForReal::new(db_arc.clone()))) + .add_service(BrainVmDaemonServer::new(BrainVmDaemonForReal::new(db_arc.clone()))) .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) .await .unwrap(); @@ -104,13 +60,13 @@ pub async fn run_service_for_stream_server() -> DuplexStream { let (client, server) = tokio::io::duplex(1024); tokio::spawn(async move { - let db_connection = surreal_brain::db::db_connection(DB_URL, DB_NS, DB_NAME).await.unwrap(); - let state = Arc::new(BrainState { db_connection }); + let db = surreal_brain::db::db_connection(DB_URL, DB_NS, DB_NAME).await.unwrap(); + let db_arc = Arc::new(db); tonic::transport::Server::builder() - .add_service(BrainGeneralCliServer::new(BrainGeneralCliForReal::new(state.clone()))) - .add_service(BrainVmCliServer::new(BrainVmCliForReal::new(state.clone()))) - .add_service(BrainVmDaemonServer::new(BrainVmDaemonForReal::new(state.clone()))) + .add_service(BrainGeneralCliServer::new(BrainGeneralCliForReal::new(db_arc.clone()))) + .add_service(BrainVmCliServer::new(BrainVmCliForReal::new(db_arc.clone()))) + .add_service(BrainVmDaemonServer::new(BrainVmDaemonForReal::new(db_arc.clone()))) .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server))) .await });