diff --git a/src/bin/brain.rs b/src/bin/brain.rs index f995272..5d6abdd 100644 --- a/src/bin/brain.rs +++ b/src/bin/brain.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use detee_shared::general_proto::brain_general_cli_server::BrainGeneralCliServer; use detee_shared::vm_proto::brain_vm_cli_server::BrainVmCliServer; use detee_shared::vm_proto::brain_vm_daemon_server::BrainVmDaemonServer; @@ -7,17 +9,22 @@ use surreal_brain::constants::{ use surreal_brain::db; use surreal_brain::grpc::BrainVmCliForReal; use surreal_brain::grpc::{BrainGeneralCliForReal, BrainVmDaemonForReal}; +use surreal_brain::BrainState; use tonic::transport::{Identity, Server, ServerTlsConfig}; #[tokio::main] async fn main() { env_logger::builder().filter_level(log::LevelFilter::Debug).init(); - db::init(DB_ADDRESS, DB_NS, DB_NAME).await.unwrap(); + + let db_connection = db::db_connection(DB_ADDRESS, DB_NS, DB_NAME).await.unwrap(); + let state = Arc::new(BrainState { db_connection }); + let addr = BRAIN_GRPC_ADDR.parse().unwrap(); - let snp_daemon_server = BrainVmDaemonServer::new(BrainVmDaemonForReal {}); - let snp_cli_server = BrainVmCliServer::new(BrainVmCliForReal {}); - let general_service_server = BrainGeneralCliServer::new(BrainGeneralCliForReal {}); + let snp_daemon_server = BrainVmDaemonServer::new(BrainVmDaemonForReal::new(state.clone())); + let snp_cli_server = BrainVmCliServer::new(BrainVmCliForReal::new(state.clone())); + let general_service_server = + BrainGeneralCliServer::new(BrainGeneralCliForReal::new(state.clone())); let cert = std::fs::read_to_string(CERT_PATH).unwrap(); let key = std::fs::read_to_string(CERT_KEY_PATH).unwrap(); diff --git a/src/bin/migration0.rs b/src/bin/migration0.rs index 68e0207..101899f 100644 --- a/src/bin/migration0.rs +++ b/src/bin/migration0.rs @@ -2,7 +2,6 @@ // and dangling impls from the model use std::error::Error; use surreal_brain::constants::{DB_ADDRESS, DB_NAME, DB_NS}; -use surreal_brain::db::init; use surreal_brain::{db, old_brain}; #[tokio::main] @@ -10,9 +9,9 @@ async fn main() -> Result<(), Box> { let old_brain_data = old_brain::BrainData::load_from_disk()?; // println!("{}", serde_yaml::to_string(&old_brain_data)?); - init(DB_ADDRESS, DB_NS, DB_NAME).await?; + let db_connection = db::db_connection(DB_ADDRESS, DB_NS, DB_NAME).await.unwrap(); - db::migration0(&old_brain_data).await?; + db::migration0(&db_connection, &old_brain_data).await?; Ok(()) } diff --git a/src/constants.rs b/src/constants.rs index 9daf5d4..d7f0c30 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -10,6 +10,8 @@ pub const DB_NAME: &str = "migration"; pub const DB_USER: &str = "root"; pub const DB_PASS: &str = "root"; +pub const DB_SCHEMA_FILE: &str = "interim_tables.surql"; + pub const ADMIN_ACCOUNTS: &[&str] = &[ "x52w7jARC5erhWWK65VZmjdGXzBK6ZDgfv1A283d8XK", "FHuecMbeC1PfjkW2JKyoicJAuiU7khgQT16QUB3Q1XdL", diff --git a/src/db.rs b/src/db.rs index 4adec67..09c2734 100644 --- a/src/db.rs +++ b/src/db.rs @@ -1,11 +1,12 @@ +use std::str::FromStr; + pub use crate::constants::{ - ACCOUNT, ACTIVE_VM, DB_ADDRESS, DB_NAME, DB_NS, DB_PASS, DB_USER, DELETED_VM, ID_ALPHABET, - NEW_VM_REQ, UPDATE_VM_REQ, VM_CONTRACT, VM_NODE, + ACCOUNT, ACTIVE_VM, DB_ADDRESS, DB_NAME, DB_NS, DB_PASS, DB_SCHEMA_FILE, DB_USER, DELETED_VM, + ID_ALPHABET, NEW_VM_REQ, UPDATE_VM_REQ, VM_CONTRACT, VM_NODE, }; use crate::old_brain; use serde::{Deserialize, Serialize}; -use std::{str::FromStr, sync::LazyLock}; use surrealdb::{ engine::remote::ws::{Client, Ws}, opt::auth::Root, @@ -15,49 +16,57 @@ use surrealdb::{ use tokio::sync::mpsc::Sender; use tokio_stream::StreamExt as _; -pub static DB: LazyLock> = LazyLock::new(Surreal::init); - #[derive(thiserror::Error, Debug)] pub enum Error { #[error("Internal DB error: {0}")] DataBase(#[from] surrealdb::Error), #[error("Daemon channel got closed: {0}")] DaemonConnection(#[from] tokio::sync::mpsc::error::SendError), + #[error(transparent)] + StdIo(#[from] std::io::Error), } -pub async fn init(db_address: &str, ns: &str, db: &str) -> surrealdb::Result<()> { - DB.connect::(db_address).await?; +pub async fn db_connection(db_address: &str, ns: &str, db: &str) -> Result, Error> { + let db_connection: Surreal = Surreal::init(); + db_connection.connect::(db_address).await?; // Sign in to the server - DB.signin(Root { username: DB_USER, password: DB_PASS }).await?; - DB.use_ns(ns).use_db(db).await?; - Ok(()) + db_connection.signin(Root { username: DB_USER, password: DB_PASS }).await?; + db_connection.use_ns(ns).use_db(db).await?; + Ok(db_connection) } pub async fn upsert_record( + db: &Surreal, table: &str, id: &str, my_record: SomeRecord, ) -> Result<(), Error> { #[derive(Deserialize)] struct Wrapper {} - let _: Option = DB.create((table, id)).content(my_record).await?; + let _: Option = db.create((table, id)).content(my_record).await?; Ok(()) } -pub async fn migration0(old_data: &old_brain::BrainData) -> surrealdb::Result<()> { +pub async fn migration0( + db: &Surreal, + old_data: &old_brain::BrainData, +) -> Result<(), Error> { let accounts: Vec = old_data.into(); let vm_nodes: Vec = old_data.into(); let app_nodes: Vec = old_data.into(); let vm_contracts: Vec = old_data.into(); + let schema = std::fs::read_to_string(DB_SCHEMA_FILE)?; + db.query(schema).await?; + println!("Inserting accounts..."); - let _: Vec = DB.insert(()).content(accounts).await?; + let _: Vec = db.insert(()).content(accounts).await?; println!("Inserting vm nodes..."); - let _: Vec = DB.insert(()).content(vm_nodes).await?; + let _: Vec = db.insert(()).content(vm_nodes).await?; println!("Inserting app nodes..."); - let _: Vec = DB.insert(()).content(app_nodes).await?; + let _: Vec = db.insert(()).content(app_nodes).await?; println!("Inserting vm contracts..."); - let _: Vec = DB.insert("vm_contract").relation(vm_contracts).await?; + let _: Vec = db.insert("vm_contract").relation(vm_contracts).await?; Ok(()) } @@ -72,9 +81,9 @@ pub struct Account { } impl Account { - pub async fn get(address: &str) -> Result { + pub async fn get(db: &Surreal, address: &str) -> Result { let id = (ACCOUNT, address); - let account: Option = DB.select(id).await?; + let account: Option = db.select(id).await?; let account = match account { Some(account) => account, None => { @@ -84,9 +93,9 @@ impl Account { Ok(account) } - pub async fn airdrop(account: &str, tokens: u64) -> Result<(), Error> { + pub async fn airdrop(db: &Surreal, account: &str, tokens: u64) -> Result<(), Error> { let tokens = tokens.saturating_mul(1_000_000_000); - let _ = DB + let _ = db .query(format!("upsert account:{account} SET balance = (balance || 0) + {tokens};")) .await?; Ok(()) @@ -94,8 +103,12 @@ impl Account { } impl Account { - pub async fn is_banned_by_node(user: &str, node: &str) -> Result { - let ban: Option = DB + pub async fn is_banned_by_node( + db: &Surreal, + user: &str, + node: &str, + ) -> Result { + let ban: Option = db .query(format!( "(select operator->ban[0] as ban from vm_node:{node} @@ -139,15 +152,15 @@ pub struct VmNodeResources { } impl VmNodeResources { - pub async fn merge(self, node_id: &str) -> Result<(), Error> { - let _: Option = DB.update((VM_NODE, node_id)).merge(self).await?; + pub async fn merge(self, db: &Surreal, node_id: &str) -> Result<(), Error> { + let _: Option = db.update((VM_NODE, node_id)).merge(self).await?; Ok(()) } } impl VmNode { - pub async fn register(self) -> Result<(), Error> { - let _: Option = DB.upsert(self.id.clone()).content(self).await?; + pub async fn register(self, db: &Surreal) -> Result<(), Error> { + let _: Option = db.upsert(self.id.clone()).content(self).await?; Ok(()) } } @@ -176,6 +189,7 @@ impl VmNodeWithReports { // TODO: find a more elegant way to do this than importing gRPC in the DB module // https://en.wikipedia.org/wiki/Dependency_inversion_principle pub async fn find_by_filters( + db: &Surreal, filters: detee_shared::snp::pb::vm_proto::VmNodeFilters, ) -> Result, Error> { let mut query = format!( @@ -208,7 +222,7 @@ impl VmNodeWithReports { query += &format!("&& ip = '{}' ", filters.ip); } query += ";"; - let mut result = DB.query(query).await?; + let mut result = db.query(query).await?; let vm_nodes: Vec = result.take(0)?; Ok(vm_nodes) } @@ -263,27 +277,27 @@ pub struct NewVmReq { } impl NewVmReq { - pub async fn get(id: &str) -> Result, Error> { - let new_vm_req: Option = DB.select((NEW_VM_REQ, id)).await?; + pub async fn get(db: &Surreal, id: &str) -> Result, Error> { + let new_vm_req: Option = db.select((NEW_VM_REQ, id)).await?; Ok(new_vm_req) } - pub async fn delete(id: &str) -> Result<(), Error> { - let _: Option = DB.delete((NEW_VM_REQ, id)).await?; + pub async fn delete(db: &Surreal, id: &str) -> Result<(), Error> { + let _: Option = db.delete((NEW_VM_REQ, id)).await?; Ok(()) } - pub async fn submit_error(id: &str, error: String) -> Result<(), Error> { + pub async fn submit_error(db: &Surreal, id: &str, error: String) -> Result<(), Error> { #[derive(Serialize)] struct NewVmError { error: String, } - let _: Option = DB.update((NEW_VM_REQ, id)).merge(NewVmError { error }).await?; + let _: Option = db.update((NEW_VM_REQ, id)).merge(NewVmError { error }).await?; Ok(()) } - pub async fn submit(self) -> Result<(), Error> { - let _: Vec = DB.insert(NEW_VM_REQ).relation(self).await?; + pub async fn submit(self, db: &Surreal) -> Result<(), Error> { + let _: Vec = db.insert(NEW_VM_REQ).relation(self).await?; Ok(()) } } @@ -297,8 +311,8 @@ pub enum NewVmResp { } impl NewVmResp { - pub async fn listen(vm_id: &str) -> Result { - let mut resp = DB + pub async fn listen(db: &Surreal, vm_id: &str) -> Result { + let mut resp = db .query(format!("live select * from {NEW_VM_REQ} where id = {NEW_VM_REQ}:{vm_id};")) .query(format!( "live select * from measurement_args where id = measurement_args:{vm_id};" @@ -364,10 +378,11 @@ pub struct ActiveVm { impl ActiveVm { pub async fn activate( + db: &Surreal, id: &str, args: detee_shared::vm_proto::MeasurementArgs, ) -> Result<(), Error> { - let new_vm_req = match NewVmReq::get(id).await? { + let new_vm_req = match NewVmReq::get(db, id).await? { Some(r) => r, None => return Ok(()), }; @@ -415,9 +430,9 @@ impl ActiveVm { collected_at: new_vm_req.created_at, }; - let _: Vec = DB.insert(()).relation(active_vm).await?; + let _: Vec = db.insert(()).relation(active_vm).await?; - NewVmReq::delete(id).await?; + NewVmReq::delete(db, id).await?; Ok(()) } } @@ -444,6 +459,7 @@ pub struct UpdateVmReq { pub async fn listen_for_node< T: Into + std::marker::Unpin + for<'de> Deserialize<'de>, >( + db: &Surreal, node: &str, tx: Sender, ) -> Result<(), Error> { @@ -457,7 +473,7 @@ pub async fn listen_for_node< } }; let mut resp = - DB.query(format!("live select * from {table_name} where out = vm_node:{node};")).await?; + db.query(format!("live select * from {table_name} where out = vm_node:{node};")).await?; let mut live_stream = resp.stream::>(0)?; while let Some(result) = live_stream.next().await { match result { @@ -497,28 +513,31 @@ pub struct DeletedVm { } impl DeletedVm { - pub async fn get_by_uuid(uuid: &str) -> Result, Error> { + pub async fn get_by_uuid(db: &Surreal, uuid: &str) -> Result, Error> { let contract: Option = - DB.query(format!("select * from {DELETED_VM}:{uuid};")).await?.take(0)?; + db.query(format!("select * from {DELETED_VM}:{uuid};")).await?.take(0)?; Ok(contract) } - pub async fn list_by_admin(admin: &str) -> Result, Error> { + pub async fn list_by_admin(db: &Surreal, admin: &str) -> Result, Error> { let mut result = - DB.query(format!("select * from {DELETED_VM} where in = {ACCOUNT}:{admin};")).await?; + db.query(format!("select * from {DELETED_VM} where in = {ACCOUNT}:{admin};")).await?; let contracts: Vec = result.take(0)?; Ok(contracts) } - pub async fn list_by_node(admin: &str) -> Result, Error> { + pub async fn list_by_node(db: &Surreal, admin: &str) -> Result, Error> { let mut result = - DB.query(format!("select * from {DELETED_VM} where out = {VM_NODE}:{admin};")).await?; + db.query(format!("select * from {DELETED_VM} where out = {VM_NODE}:{admin};")).await?; let contracts: Vec = result.take(0)?; Ok(contracts) } - pub async fn list_by_operator(operator: &str) -> Result, Error> { - let mut result = DB + pub async fn list_by_operator( + db: &Surreal, + operator: &str, + ) -> Result, Error> { + let mut result = db .query(format!( "select (select * from ->operator->vm_node<-{DELETED_VM}) as contracts @@ -596,30 +615,33 @@ pub struct ActiveVmWithNode { } impl ActiveVmWithNode { - pub async fn get_by_uuid(uuid: &str) -> Result, Error> { + pub async fn get_by_uuid(db: &Surreal, uuid: &str) -> Result, Error> { let contract: Option = - DB.query(format!("select * from {ACTIVE_VM}:{uuid} fetch out;")).await?.take(0)?; + db.query(format!("select * from {ACTIVE_VM}:{uuid} fetch out;")).await?.take(0)?; Ok(contract) } - pub async fn list_by_admin(admin: &str) -> Result, Error> { - let mut result = DB + pub async fn list_by_admin(db: &Surreal, admin: &str) -> Result, Error> { + let mut result = db .query(format!("select * from {ACTIVE_VM} where in = {ACCOUNT}:{admin} fetch out;")) .await?; let contracts: Vec = result.take(0)?; Ok(contracts) } - pub async fn list_by_node(admin: &str) -> Result, Error> { - let mut result = DB + pub async fn list_by_node(db: &Surreal, admin: &str) -> Result, Error> { + let mut result = db .query(format!("select * from {ACTIVE_VM} where out = {VM_NODE}:{admin} fetch out;")) .await?; let contracts: Vec = result.take(0)?; Ok(contracts) } - pub async fn list_by_operator(operator: &str) -> Result, Error> { - let mut result = DB + pub async fn list_by_operator( + db: &Surreal, + operator: &str, + ) -> Result, Error> { + let mut result = db .query(format!( "select (select * from ->operator->vm_node<-{ACTIVE_VM} fetch out) as contracts @@ -748,11 +770,12 @@ pub struct Report { impl Report { // TODO: test this functionality and remove this comment pub async fn create( + db: &Surreal, from_account: RecordId, to_node: RecordId, reason: String, ) -> Result<(), Error> { - let _: Vec = DB + let _: Vec = db .insert("report") .relation(Report { from_account, to_node, created_at: Datetime::default(), reason }) .await?; @@ -773,8 +796,8 @@ pub struct Operator { } impl Operator { - pub async fn list() -> Result, Error> { - let mut result = DB + pub async fn list(db: &Surreal) -> Result, Error> { + let mut result = db .query( "array::distinct(array::flatten( [ (select operator from vm_node group by operator).operator, @@ -786,15 +809,15 @@ impl Operator { let operator_accounts: Vec = result.take(0)?; let mut operators: Vec = Vec::new(); for account in operator_accounts.iter() { - if let Some(operator) = Self::inspect(&account.key().to_string()).await? { + if let Some(operator) = Self::inspect(db, &account.key().to_string()).await? { operators.push(operator); } } Ok(operators) } - pub async fn inspect(account: &str) -> Result, Error> { - let mut result = DB + pub async fn inspect(db: &Surreal, account: &str) -> Result, Error> { + let mut result = db .query(format!( "$vm_nodes = (select id from vm_node where operator = account:{account}).id; $app_nodes = (select id from app_node where operator = account:{account}).id; @@ -815,10 +838,11 @@ impl Operator { } pub async fn inspect_nodes( + db: &Surreal, account: &str, ) -> Result<(Option, Vec, Vec), Error> { - let operator = Self::inspect(account).await?; - let mut result = DB + let operator = Self::inspect(db, account).await?; + let mut result = db .query(format!( "select *, operator, <-report.* as reports from vm_node where operator = account:{account};" diff --git a/src/grpc.rs b/src/grpc.rs index 93decea..415f778 100644 --- a/src/grpc.rs +++ b/src/grpc.rs @@ -1,6 +1,6 @@ #![allow(dead_code)] use crate::constants::{ACCOUNT, ADMIN_ACCOUNTS, VM_NODE}; -use crate::db; +use crate::{db, BrainState}; use detee_shared::app_proto::{AppContract, AppNodeListResp}; use detee_shared::{ common_proto::{Empty, Pubkey}, @@ -18,13 +18,22 @@ use nanoid::nanoid; use log::info; use std::pin::Pin; +use std::sync::Arc; use surrealdb::RecordId; use tokio::sync::mpsc; use tokio_stream::wrappers::ReceiverStream; use tokio_stream::{Stream, StreamExt}; use tonic::{Request, Response, Status, Streaming}; -pub struct BrainGeneralCliForReal {} +pub struct BrainGeneralCliForReal { + state: Arc, +} + +impl BrainGeneralCliForReal { + pub fn new(state: Arc) -> Self { + Self { state } + } +} impl From for AccountBalance { fn from(account: db::Account) -> Self { @@ -239,7 +248,15 @@ impl From for db::VmNodeResources { } } -pub struct BrainVmDaemonForReal {} +pub struct BrainVmDaemonForReal { + pub state: Arc, +} + +impl BrainVmDaemonForReal { + pub fn new(state: Arc) -> Self { + Self { state } + } +} #[tonic::async_trait] impl BrainVmDaemon for BrainVmDaemonForReal { @@ -267,11 +284,12 @@ impl BrainVmDaemon for BrainVmDaemonForReal { max_ports_per_vm: 0, offline_minutes: 0, } - .register() + .register(&self.state.db_connection) .await?; info!("Sending existing contracts to {}", req.node_pubkey); - let contracts = db::ActiveVmWithNode::list_by_node(&req.node_pubkey).await?; + let contracts = + db::ActiveVmWithNode::list_by_node(&self.state.db_connection, &req.node_pubkey).await?; let (tx, rx) = mpsc::channel(6); tokio::spawn(async move { for contract in contracts { @@ -299,10 +317,12 @@ impl BrainVmDaemon for BrainVmDaemonForReal { let (tx, rx) = mpsc::channel(6); { + let state = self.state.clone(); let pubkey = pubkey.clone(); let tx = tx.clone(); tokio::spawn(async move { - match db::listen_for_node::(&pubkey, tx).await { + match db::listen_for_node::(&state.db_connection, &pubkey, tx).await + { Ok(()) => log::info!("db::VmContract::listen_for_node ended for {pubkey}"), Err(e) => { log::warn!("db::VmContract::listen_for_node errored for {pubkey}: {e}") @@ -311,17 +331,26 @@ impl BrainVmDaemon for BrainVmDaemonForReal { }); } { + let state = self.state.clone(); let pubkey = pubkey.clone(); let tx = tx.clone(); tokio::spawn(async move { - let _ = db::listen_for_node::(&pubkey, tx.clone()).await; + let _ = + db::listen_for_node::(&state.db_connection, &pubkey, tx.clone()) + .await; }); } { + let state = self.state.clone(); let pubkey = pubkey.clone(); let tx = tx.clone(); tokio::spawn(async move { - let _ = db::listen_for_node::(&pubkey, tx.clone()).await; + let _ = db::listen_for_node::( + &state.db_connection, + &pubkey, + tx.clone(), + ) + .await; }); } @@ -361,17 +390,27 @@ impl BrainVmDaemon for BrainVmDaemonForReal { // TODO: move new_vm_req to active_vm // also handle failure properly if !new_vm_resp.error.is_empty() { - db::NewVmReq::submit_error(&new_vm_resp.uuid, new_vm_resp.error) - .await?; + db::NewVmReq::submit_error( + &self.state.db_connection, + &new_vm_resp.uuid, + new_vm_resp.error, + ) + .await?; } else { db::upsert_record( + &self.state.db_connection, "measurement_args", &new_vm_resp.uuid, new_vm_resp.args.clone(), ) .await?; if let Some(args) = new_vm_resp.args { - db::ActiveVm::activate(&new_vm_resp.uuid, args).await?; + db::ActiveVm::activate( + &self.state.db_connection, + &new_vm_resp.uuid, + args, + ) + .await?; } } } @@ -381,7 +420,7 @@ impl BrainVmDaemon for BrainVmDaemonForReal { } Some(vm_daemon_message::Msg::VmNodeResources(node_resources)) => { let node_resources: db::VmNodeResources = node_resources.into(); - node_resources.merge(&pubkey).await?; + node_resources.merge(&self.state.db_connection, &pubkey).await?; } _ => {} }, @@ -405,24 +444,27 @@ impl BrainGeneralCli for BrainGeneralCliForReal { async fn get_balance(&self, req: Request) -> Result, Status> { let req = check_sig_from_req(req)?; - Ok(Response::new(db::Account::get(&req.pubkey).await?.into())) + Ok(Response::new(db::Account::get(&self.state.db_connection, &req.pubkey).await?.into())) } async fn report_node(&self, req: Request) -> Result, Status> { let req = check_sig_from_req(req)?; - let (account, node) = match db::ActiveVmWithNode::get_by_uuid(&req.contract).await? { - Some(vm_contract) - if vm_contract.admin.key().to_string() == req.admin_pubkey - && vm_contract.vm_node.id.key().to_string() == req.node_pubkey => + let (account, node) = + match db::ActiveVmWithNode::get_by_uuid(&self.state.db_connection, &req.contract) + .await? { - (vm_contract.admin, vm_contract.vm_node.id) - } - _ => { - // TODO: Hey, Noor! Please add app contract here. - return Err(Status::unauthenticated("No contract found by this ID.")); - } - }; - db::Report::create(account, node, req.reason).await?; + Some(vm_contract) + if vm_contract.admin.key().to_string() == req.admin_pubkey + && vm_contract.vm_node.id.key().to_string() == req.node_pubkey => + { + (vm_contract.admin, vm_contract.vm_node.id) + } + _ => { + // TODO: Hey, Noor! Please add app contract here. + return Err(Status::unauthenticated("No contract found by this ID.")); + } + }; + db::Report::create(&self.state.db_connection, account, node, req.reason).await?; Ok(Response::new(Empty {})) } @@ -431,7 +473,7 @@ impl BrainGeneralCli for BrainGeneralCliForReal { req: Request, ) -> Result, Status> { let _ = check_sig_from_req(req)?; - let operators = db::Operator::list().await?; + let operators = db::Operator::list(&self.state.db_connection).await?; let (tx, rx) = mpsc::channel(6); tokio::spawn(async move { for op in operators { @@ -446,7 +488,9 @@ impl BrainGeneralCli for BrainGeneralCliForReal { &self, req: Request, ) -> Result, Status> { - match db::Operator::inspect_nodes(&req.into_inner().pubkey).await? { + match db::Operator::inspect_nodes(&self.state.db_connection, &req.into_inner().pubkey) + .await? + { (Some(op), vm_nodes, app_nodes) => Ok(Response::new(InspectOperatorResp { operator: Some(op.into()), vm_nodes: vm_nodes.into_iter().map(|n| n.into()).collect(), @@ -490,7 +534,7 @@ impl BrainGeneralCli for BrainGeneralCliForReal { async fn airdrop(&self, req: Request) -> Result, Status> { check_admin_key(&req)?; let req = check_sig_from_req(req)?; - db::Account::airdrop(&req.pubkey, req.tokens).await?; + db::Account::airdrop(&self.state.db_connection, &req.pubkey, req.tokens).await?; Ok(Response::new(Empty {})) } @@ -557,7 +601,15 @@ impl BrainGeneralCli for BrainGeneralCliForReal { } } -pub struct BrainVmCliForReal {} +pub struct BrainVmCliForReal { + state: Arc, +} + +impl BrainVmCliForReal { + pub fn new(state: Arc) -> Self { + Self { state } + } +} #[tonic::async_trait] impl BrainVmCli for BrainVmCliForReal { @@ -567,17 +619,26 @@ impl BrainVmCli for BrainVmCliForReal { async fn new_vm(&self, req: Request) -> Result, Status> { let req = check_sig_from_req(req)?; info!("New VM requested via CLI: {req:?}"); - if db::Account::is_banned_by_node(&req.admin_pubkey, &req.node_pubkey).await? { + if db::Account::is_banned_by_node( + &self.state.db_connection, + &req.admin_pubkey, + &req.node_pubkey, + ) + .await? + { return Err(Status::permission_denied("This operator banned you. What did you do?")); } let new_vm_req: db::NewVmReq = req.into(); let id = new_vm_req.id.key().to_string(); + + let state = self.state.clone(); + let (oneshot_tx, oneshot_rx) = tokio::sync::oneshot::channel(); tokio::spawn(async move { - let _ = oneshot_tx.send(db::NewVmResp::listen(&id).await); + let _ = oneshot_tx.send(db::NewVmResp::listen(&state.db_connection, &id).await); }); - new_vm_req.submit().await?; + new_vm_req.submit(&self.state.db_connection).await?; match oneshot_rx.await { Ok(new_vm_resp) => Ok(Response::new(new_vm_resp?.into())), @@ -639,16 +700,24 @@ impl BrainVmCli for BrainVmCliForReal { ); let mut contracts = Vec::new(); if !req.uuid.is_empty() { - if let Some(specific_contract) = db::ActiveVmWithNode::get_by_uuid(&req.uuid).await? { + if let Some(specific_contract) = + db::ActiveVmWithNode::get_by_uuid(&self.state.db_connection, &req.uuid).await? + { if specific_contract.admin.key().to_string() == req.wallet { contracts.push(specific_contract); } // TODO: allow operator to inspect contracts } } else if req.as_operator { - contracts.append(&mut db::ActiveVmWithNode::list_by_operator(&req.wallet).await?); + contracts.append( + &mut db::ActiveVmWithNode::list_by_operator(&self.state.db_connection, &req.wallet) + .await?, + ); } else { - contracts.append(&mut db::ActiveVmWithNode::list_by_admin(&req.wallet).await?); + contracts.append( + &mut db::ActiveVmWithNode::list_by_admin(&self.state.db_connection, &req.wallet) + .await?, + ); } let (tx, rx) = mpsc::channel(6); tokio::spawn(async move { @@ -666,7 +735,7 @@ impl BrainVmCli for BrainVmCliForReal { ) -> Result, tonic::Status> { let req = check_sig_from_req(req)?; info!("CLI requested ListVmNodesStream: {req:?}"); - let nodes = db::VmNodeWithReports::find_by_filters(req).await?; + let nodes = db::VmNodeWithReports::find_by_filters(&self.state.db_connection, req).await?; let (tx, rx) = mpsc::channel(6); tokio::spawn(async move { for node in nodes { @@ -684,7 +753,7 @@ impl BrainVmCli for BrainVmCliForReal { let req = check_sig_from_req(req)?; info!("Unknown CLI requested ListVmNodesStream: {req:?}"); // TODO: optimize this query so that it gets only one node - let nodes = db::VmNodeWithReports::find_by_filters(req).await?; + let nodes = db::VmNodeWithReports::find_by_filters(&self.state.db_connection, req).await?; if let Some(node) = nodes.into_iter().next() { return Ok(Response::new(node.into())); } diff --git a/src/lib.rs b/src/lib.rs index 9e4aaa0..10377e2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,10 @@ +use surrealdb::{engine::remote::ws::Client, Surreal}; + pub mod constants; pub mod db; pub mod grpc; pub mod old_brain; + +pub struct BrainState { + pub db_connection: Surreal, +} diff --git a/tests/common/prepare_test_env.rs b/tests/common/prepare_test_env.rs index cdeaed8..14b2ec4 100644 --- a/tests/common/prepare_test_env.rs +++ b/tests/common/prepare_test_env.rs @@ -5,45 +5,91 @@ use detee_shared::{ }; use hyper_util::rt::TokioIo; use std::net::SocketAddr; +use std::sync::Arc; use surreal_brain::grpc::{BrainGeneralCliForReal, BrainVmCliForReal, BrainVmDaemonForReal}; +use surreal_brain::BrainState; use tokio::io::DuplexStream; use tokio::{net::TcpListener, sync::OnceCell}; use tonic::transport::{Channel, Endpoint, Server, Uri}; use tower::service_fn; +use surrealdb::engine::remote::ws::Client; +use surrealdb::Surreal; + pub const DB_URL: &str = "localhost:8000"; pub const DB_NS: &str = "test_brain"; pub const DB_NAME: &str = "test_migration_db"; pub static DB_STATE: OnceCell<()> = OnceCell::const_new(); -pub async fn prepare_test_db() { +pub async fn prepare_test_db() -> Surreal { + let db_connection = surreal_brain::db::db_connection(DB_URL, DB_NS, DB_NAME).await.unwrap(); DB_STATE .get_or_init(|| async { - surreal_brain::db::init(DB_URL, DB_NS, DB_NAME) - .await - .expect("Failed to initialize the database"); + // surreal_brain::db::init(DB_URL, DB_NS, DB_NAME) + // .await + // .expect("Failed to initialize the database"); + + // surreal_brain::db::DB + // .set(surreal_brain::db::db_connection(DB_URL, DB_NS, DB_NAME).await.unwrap()) + // .unwrap(); let old_brain_data = surreal_brain::old_brain::BrainData::load_from_disk().unwrap(); - surreal_brain::db::DB.query(format!("REMOVE DATABASE {DB_NAME}")).await.unwrap(); - surreal_brain::db::DB + db_connection.query(format!("REMOVE DATABASE {DB_NAME}")).await.unwrap(); + db_connection .query(std::fs::read_to_string("interim_tables.surql").unwrap()) .await .unwrap(); - surreal_brain::db::migration0(&old_brain_data).await.unwrap(); + surreal_brain::db::migration0(&db_connection, &old_brain_data).await.unwrap(); }) .await; + db_connection } +/* +use db::DB; + +#[tokio::main] +async fn main() { + env_logger::builder().filter_level(log::LevelFilter::Debug).init(); + // db::init(DB_ADDRESS, DB_NS, DB_NAME).await.unwrap(); + + DB.set(db::db_connection(DB_ADDRESS, DB_NS, DB_NAME).await.unwrap()).unwrap(); + + let db_connection = db::db_connection(DB_ADDRESS, DB_NS, DB_NAME).await.unwrap(); + let state = Arc::new(BrainState { db_connection }); + + let addr = BRAIN_GRPC_ADDR.parse().unwrap(); + + let snp_daemon_server = BrainVmDaemonServer::new(BrainVmDaemonForReal::new(state.clone())); + let snp_cli_server = BrainVmCliServer::new(BrainVmCliForReal::new(state.clone())); + let general_service_server = + BrainGeneralCliServer::new(BrainGeneralCliForReal::new(state.clone())); + + let cert = std::fs::read_to_string(CERT_PATH).unwrap(); + let key = std::fs::read_to_string(CERT_KEY_PATH).unwrap(); + + let identity = Identity::from_pem(cert, key); + + Server::builder() + .tls_config(ServerTlsConfig::new().identity(identity)) + .unwrap() + .add_service(snp_daemon_server) + .add_service(snp_cli_server) + */ + pub async fn run_service_in_background() -> SocketAddr { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); tokio::spawn(async move { + let db_connection = surreal_brain::db::db_connection(DB_URL, DB_NS, DB_NAME).await.unwrap(); + let state = Arc::new(BrainState { db_connection }); + Server::builder() - .add_service(BrainGeneralCliServer::new(BrainGeneralCliForReal {})) - .add_service(BrainVmCliServer::new(BrainVmCliForReal {})) - .add_service(BrainVmDaemonServer::new(BrainVmDaemonForReal {})) + .add_service(BrainGeneralCliServer::new(BrainGeneralCliForReal::new(state.clone()))) + .add_service(BrainVmCliServer::new(BrainVmCliForReal::new(state.clone()))) + .add_service(BrainVmDaemonServer::new(BrainVmDaemonForReal::new(state.clone()))) .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) .await .unwrap(); @@ -58,10 +104,13 @@ pub async fn run_service_for_stream_server() -> DuplexStream { let (client, server) = tokio::io::duplex(1024); tokio::spawn(async move { + let db_connection = surreal_brain::db::db_connection(DB_URL, DB_NS, DB_NAME).await.unwrap(); + let state = Arc::new(BrainState { db_connection }); + tonic::transport::Server::builder() - .add_service(BrainGeneralCliServer::new(BrainGeneralCliForReal {})) - .add_service(BrainVmCliServer::new(BrainVmCliForReal {})) - .add_service(BrainVmDaemonServer::new(BrainVmDaemonForReal {})) + .add_service(BrainGeneralCliServer::new(BrainGeneralCliForReal::new(state.clone()))) + .add_service(BrainVmCliServer::new(BrainVmCliForReal::new(state.clone()))) + .add_service(BrainVmDaemonServer::new(BrainVmDaemonForReal::new(state.clone()))) .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server))) .await }); diff --git a/tests/common/vm_cli_utils.rs b/tests/common/vm_cli_utils.rs index 99b3fa9..6c8d059 100644 --- a/tests/common/vm_cli_utils.rs +++ b/tests/common/vm_cli_utils.rs @@ -3,9 +3,16 @@ use detee_shared::vm_proto; use detee_shared::vm_proto::brain_vm_cli_client::BrainVmCliClient; use surreal_brain::constants::{ACTIVE_VM, NEW_VM_REQ}; use surreal_brain::db; +use surrealdb::engine::remote::ws::Client; +use surrealdb::Surreal; use tonic::transport::Channel; -pub async fn create_new_vm(key: Key, node_pubkey: String, brain_channel: Channel) -> String { +pub async fn create_new_vm( + db: &Surreal, + key: Key, + node_pubkey: String, + brain_channel: Channel, +) -> String { let new_vm_req = vm_proto::NewVmReq { admin_pubkey: key.pubkey.clone(), node_pubkey, @@ -26,14 +33,14 @@ pub async fn create_new_vm(key: Key, node_pubkey: String, brain_channel: Channel tokio::time::sleep(tokio::time::Duration::from_millis(700)).await; let vm_req_db: Option = - db::DB.select((NEW_VM_REQ, new_vm_resp.uuid.clone())).await.unwrap(); + db.select((NEW_VM_REQ, new_vm_resp.uuid.clone())).await.unwrap(); if let Some(new_vm_req) = vm_req_db { panic!("New VM request found in DB: {:?}", new_vm_req); } let active_vm_op: Option = - db::DB.select((ACTIVE_VM, new_vm_resp.uuid.clone())).await.unwrap(); + db.select((ACTIVE_VM, new_vm_resp.uuid.clone())).await.unwrap(); let active_vm = active_vm_op.unwrap(); active_vm.id.key().to_string() diff --git a/tests/grpc_test.rs b/tests/grpc_test.rs index b42d410..824b951 100644 --- a/tests/grpc_test.rs +++ b/tests/grpc_test.rs @@ -13,7 +13,7 @@ use detee_shared::{ }; use futures::StreamExt; use surreal_brain::constants::VM_NODE; -use surreal_brain::db::{self, VmNodeWithReports}; +use surreal_brain::db::VmNodeWithReports; mod common; @@ -39,19 +39,19 @@ async fn test_general_balance() { #[tokio::test] async fn test_vm_creation() { - prepare_test_db().await; + let db = prepare_test_db().await; let brain_channel = run_service_for_stream().await; let daemon_key = mock_vm_daemon(brain_channel.clone()).await; let key = Key::new(); - let _ = create_new_vm(key.clone(), daemon_key.clone(), brain_channel.clone()).await; + let _ = create_new_vm(&db, key.clone(), daemon_key.clone(), brain_channel.clone()).await; } #[tokio::test] async fn test_report_node() { - prepare_test_db().await; + let db = prepare_test_db().await; let brain_channel = run_service_for_stream().await; let daemon_key = mock_vm_daemon(brain_channel.clone()).await; @@ -73,7 +73,8 @@ async fn test_report_node() { println!("Report error: {:?}", report_error); assert_eq!(report_error.message(), "No contract found by this ID."); - let active_vm_id = create_new_vm(key.clone(), daemon_key.clone(), brain_channel.clone()).await; + let active_vm_id = + create_new_vm(&db, key.clone(), daemon_key.clone(), brain_channel.clone()).await; let reason = String::from("something went wrong on vm"); let report_req = ReportNodeReq { @@ -89,7 +90,7 @@ async fn test_report_node() { .unwrap() .into_inner(); - let vm_nodes: Vec = db::DB + let vm_nodes: Vec = db .query(format!( "SELECT *, <-report.* as reports FROM {VM_NODE} WHERE id = {VM_NODE}:{daemon_key};" )) diff --git a/tests/grpc_vm_daemon_test.rs b/tests/grpc_vm_daemon_test.rs index 681ae13..02da943 100644 --- a/tests/grpc_vm_daemon_test.rs +++ b/tests/grpc_vm_daemon_test.rs @@ -28,7 +28,7 @@ async fn test_reg_vm_node() { #[tokio::test] async fn test_brain_message() { env_logger::builder().filter_level(log::LevelFilter::Info).init(); - let _ = prepare_test_db().await; + let db = prepare_test_db().await; let brain_channel = run_service_for_stream().await; let daemon_key = mock_vm_daemon(brain_channel.clone()).await; @@ -51,8 +51,7 @@ async fn test_brain_message() { assert!(new_vm_resp.uuid.len() == 40); let id = ("measurement_args", new_vm_resp.uuid); - let data_in_db: detee_shared::vm_proto::MeasurementArgs = - surreal_brain::db::DB.select(id).await.unwrap().unwrap(); + let data_in_db: detee_shared::vm_proto::MeasurementArgs = db.select(id).await.unwrap().unwrap(); assert_eq!(data_in_db, new_vm_resp.args.unwrap()); }