From cd44d824637cd1d9f37f71bbf3c03977e1b3f1f1 Mon Sep 17 00:00:00 2001 From: ghe0 Date: Tue, 6 May 2025 04:20:51 +0300 Subject: [PATCH] VM updates and VM deletion --- .env | 4 +- interim_tables.surql | 2 + src/bin/brain.rs | 20 ++- src/constants.rs | 6 +- src/db/mod.rs | 8 +- src/db/vm.rs | 283 ++++++++++++++++++++++++++----- src/grpc/types.rs | 44 ++++- src/grpc/vm.rs | 125 ++++++++++---- tests/common/prepare_test_env.rs | 6 +- 9 files changed, 395 insertions(+), 103 deletions(-) diff --git a/.env b/.env index e58ac39..8984b2f 100644 --- a/.env +++ b/.env @@ -2,4 +2,6 @@ DB_URL = "localhost:8000" DB_USER = "root" DB_PASS = "root" DB_NAMESPACE = "brain" -DB_NAME = "migration" \ No newline at end of file +DB_NAME = "migration" +CERT_PATH = "./tmp/brain-crt.pem" +CERT_KEY_PATH = "./tmp/brain-key.pem" diff --git a/interim_tables.surql b/interim_tables.surql index 8911c34..8c4d0ba 100644 --- a/interim_tables.surql +++ b/interim_tables.surql @@ -60,6 +60,8 @@ DEFINE FIELD dtrfs_sha ON TABLE update_vm_req TYPE string; DEFINE FIELD dtrfs_url ON TABLE update_vm_req TYPE string; DEFINE FIELD kernel_sha ON TABLE update_vm_req TYPE string; DEFINE FIELD kernel_url ON TABLE update_vm_req TYPE string; +DEFINE FIELD created_at ON TABLE update_vm_req TYPE datetime; +DEFINE FIELD error ON TABLE update_vm_req TYPE string; DEFINE TABLE deleted_vm TYPE RELATION FROM account TO vm_node SCHEMAFULL; DEFINE FIELD hostname ON TABLE deleted_vm TYPE string; diff --git a/src/bin/brain.rs b/src/bin/brain.rs index 6cf2e26..62e1fb8 100644 --- a/src/bin/brain.rs +++ b/src/bin/brain.rs @@ -12,14 +12,16 @@ use tonic::transport::{Identity, Server, ServerTlsConfig}; #[tokio::main] async fn main() { - dotenv().ok(); + if dotenv::from_filename("/etc/detee/brain/config.ini").is_err() { + dotenv().ok(); + } env_logger::builder().filter_level(log::LevelFilter::Debug).init(); - let db_url = std::env::var("DB_URL").expect("DB_URL not set in .env"); - let db_user = std::env::var("DB_USER").expect("DB_USER not set in .env"); - let db_pass = std::env::var("DB_PASS").expect("DB_PASS not set in .env"); - let db_ns = std::env::var("DB_NAMESPACE").expect("DB_NAMESPACE not set in .env"); - let db_name = std::env::var("DB_NAME").expect("DB_NAME not set in .env"); + let db_url = std::env::var("DB_URL").expect("the environment variable DB_URL is not set"); + let db_user = std::env::var("DB_USER").expect("the environment variable DB_USER is not set"); + let db_pass = std::env::var("DB_PASS").expect("the environment variable DB_PASS is not set"); + let db_ns = std::env::var("DB_NAMESPACE").expect("the env variable DB_NAMESPACE is not set"); + let db_name = std::env::var("DB_NAME").expect("the environment variable DB_NAME is not set"); let db = db::db_connection(&db_url, &db_user, &db_pass, &db_ns, &db_name).await.unwrap(); let db_arc = Arc::new(db); @@ -30,8 +32,10 @@ async fn main() { let snp_cli_server = BrainVmCliServer::new(VmCliServer::new(db_arc.clone())); let general_service_server = BrainGeneralCliServer::new(GeneralCliServer::new(db_arc.clone())); - let cert = std::fs::read_to_string(CERT_PATH).unwrap(); - let key = std::fs::read_to_string(CERT_KEY_PATH).unwrap(); + let cert_path = std::env::var("CERT_PATH").unwrap_or(CERT_PATH.to_string()); + let key_path = std::env::var("CERT_KEY_PATH").unwrap_or(CERT_KEY_PATH.to_string()); + let cert = std::fs::read_to_string(cert_path).unwrap(); + let key = std::fs::read_to_string(key_path).unwrap(); let identity = Identity::from_pem(cert, key); diff --git a/src/constants.rs b/src/constants.rs index c505d60..c1a3ca1 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -1,6 +1,7 @@ pub const BRAIN_GRPC_ADDR: &str = "0.0.0.0:31337"; -pub const CERT_PATH: &str = "./tmp/brain-crt.pem"; -pub const CERT_KEY_PATH: &str = "./tmp/brain-key.pem"; +pub const CERT_PATH: &str = "/etc/detee/brain/brain-crt.pem"; +pub const CERT_KEY_PATH: &str = "/etc/detee/brain/brain-key.pem"; +pub const CONFIG_PATH: &str = "/etc/detee/brain/config.ini"; pub const DB_SCHEMA_FILE: &str = "interim_tables.surql"; @@ -15,6 +16,7 @@ pub const OLD_BRAIN_DATA_PATH: &str = "./saved_data.yaml"; pub const ACCOUNT: &str = "account"; pub const VM_NODE: &str = "vm_node"; pub const ACTIVE_VM: &str = "active_vm"; +pub const VM_UPDATE_EVENT: &str = "vm_update_event"; pub const NEW_VM_REQ: &str = "new_vm_req"; pub const UPDATE_VM_REQ: &str = "update_vm_req"; pub const DELETED_VM: &str = "deleted_vm"; diff --git a/src/db/mod.rs b/src/db/mod.rs index 5e2e5f4..24d08ee 100644 --- a/src/db/mod.rs +++ b/src/db/mod.rs @@ -29,6 +29,7 @@ pub mod prelude { pub use super::general::*; pub use super::vm::*; pub use super::*; + pub use detee_shared::snp::pb::vm_proto::{MeasurementArgs, VmNodeFilters}; } pub async fn db_connection( @@ -78,12 +79,12 @@ pub async fn upsert_record( ) -> Result<(), Error> { #[derive(Deserialize)] struct Wrapper {} - let _: Option = db.create((table, id)).content(my_record).await?; + let _: Option = db.upsert((table, id)).content(my_record).await?; Ok(()) } pub async fn live_vmnode_msgs< - T: Into + std::marker::Unpin + for<'de> Deserialize<'de>, + T: std::fmt::Debug + Into + std::marker::Unpin + for<'de> Deserialize<'de>, >( db: &Surreal, node: &str, @@ -104,12 +105,13 @@ pub async fn live_vmnode_msgs< while let Some(result) = live_stream.next().await { match result { Ok(notification) => { + log::debug!("Got notification for node {node}: {notification:?}"); if notification.action == surrealdb::Action::Create { tx.send(notification.data.into()).await? } } Err(e) => { - log::error!("listen_for_{table_name} DB stream failed for {node}: {e}"); + log::error!("live_vmnode_msgs for {table_name} DB stream failed for {node}: {e}"); return Err(Error::from(e)); } } diff --git a/src/db/vm.rs b/src/db/vm.rs index 8d7ca71..6b87bfc 100644 --- a/src/db/vm.rs +++ b/src/db/vm.rs @@ -2,8 +2,10 @@ use std::str::FromStr; use std::time::Duration; use super::Error; -use crate::constants::{ACCOUNT, ACTIVE_VM, DELETED_VM, NEW_VM_REQ, VM_NODE}; -use crate::db::general::Report; +use crate::constants::{ + ACCOUNT, ACTIVE_VM, DELETED_VM, NEW_VM_REQ, UPDATE_VM_REQ, VM_NODE, VM_UPDATE_EVENT, +}; +use crate::db::{MeasurementArgs, Report, VmNodeFilters}; use crate::old_brain; use serde::{Deserialize, Serialize}; use surrealdb::engine::remote::ws::Client; @@ -80,7 +82,7 @@ impl VmNodeWithReports { // https://en.wikipedia.org/wiki/Dependency_inversion_principle pub async fn find_by_filters( db: &Surreal, - filters: detee_shared::snp::pb::vm_proto::VmNodeFilters, + filters: VmNodeFilters, ) -> Result, Error> { let mut query = format!( "select *, <-report.* as reports from {VM_NODE} where @@ -193,35 +195,54 @@ impl NewVmReq { } /// first string is the vm_id -pub enum NewVmResp { - // TODO: find a more elegant way to do this than importing gRPC in the DB module - // https://en.wikipedia.org/wiki/Dependency_inversion_principle - Args(String, detee_shared::snp::pb::vm_proto::MeasurementArgs), +pub enum WrappedMeasurement { + Args(String, MeasurementArgs), Error(String, String), } -impl NewVmResp { - pub async fn listen(db: &Surreal, vm_id: &str) -> Result { +impl WrappedMeasurement { + /// table must be NEW_VM_REQ or UPDATE_VM_REQ + /// it will however be enforced if you send anything else + pub async fn listen( + db: &Surreal, + vm_id: &str, + table: &str, + ) -> Result { + let table = match table { + UPDATE_VM_REQ => UPDATE_VM_REQ, + _ => NEW_VM_REQ, + }; + #[derive(Deserialize)] + struct ErrorMessage { + error: String, + } let mut resp = db - .query(format!("live select * from {NEW_VM_REQ} where id = {NEW_VM_REQ}:{vm_id};")) + .query(format!("live select error from {table} where id = {NEW_VM_REQ}:{vm_id};")) .query(format!( "live select * from measurement_args where id = measurement_args:{vm_id};" )) .await?; - let mut new_vm_stream = resp.stream::>(0)?; - let mut args_stream = - resp.stream::>(1)?; + let mut error_stream = resp.stream::>(0)?; + let mut args_stream = resp.stream::>(1)?; + + let args: Option = db.delete(("measurement_args", vm_id)).await?; + if let Some(args) = args { + return Ok(Self::Args(vm_id.to_string(), args)); + } tokio::time::timeout(Duration::from_secs(10), async { loop { tokio::select! { - new_vm_req_notif = new_vm_stream.next() => { - log::debug!("Got stream 1..."); - if let Some(new_vm_req_notif) = new_vm_req_notif { - match new_vm_req_notif { - Ok(new_vm_req_notif) => { - if new_vm_req_notif.action == surrealdb::Action::Update && !new_vm_req_notif.data.error.is_empty() { - return Ok::(Self::Error(vm_id.to_string(), new_vm_req_notif.data.error)); + error_notification = error_stream.next() => { + if let Some(err_notif) = error_notification { + match err_notif { + Ok(err_notif) => { + if err_notif.action == surrealdb::Action::Update + && !err_notif.data.error.is_empty() { + return Ok::( + Self::Error(vm_id.to_string(), + err_notif.data.error) + ); }; }, Err(e) => return Err(e.into()), @@ -233,6 +254,7 @@ impl NewVmResp { match args_notif { Ok(args_notif) => { if args_notif.action == surrealdb::Action::Create { + let _: Option = db.delete(("measurement_args", vm_id)).await?; return Ok(Self::Args(vm_id.to_string(), args_notif.data)); }; }, @@ -242,11 +264,12 @@ impl NewVmResp { } } } - }).await? + }) + .await? } } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct ActiveVm { pub id: RecordId, #[serde(rename = "in")] @@ -269,10 +292,32 @@ pub struct ActiveVm { } impl ActiveVm { + /// total hardware units of this VM + fn total_units(&self) -> u64 { + // TODO: Optimize this based on price of hardware. + // I tried, but this can be done better. + // Storage cost should also be based on tier + (self.vcpus as u64 * 10) + + ((self.memory_mb + 256) as u64 / 200) + + (self.disk_size_gb as u64 / 10) + + (!self.public_ipv4.is_empty() as u64 * 10) + } + + /// Returns price per minute in nanoLP + pub fn price_per_minute(&self) -> u64 { + self.total_units() * self.price_per_unit + } + + pub async fn get_by_uuid(db: &Surreal, uuid: &str) -> Result, Error> { + let contract: Option = + db.query(format!("select * from {ACTIVE_VM}:{uuid};")).await?.take(0)?; + Ok(contract) + } + pub async fn activate( db: &Surreal, id: &str, - args: detee_shared::vm_proto::MeasurementArgs, + args: MeasurementArgs, ) -> Result<(), Error> { let new_vm_req = match NewVmReq::get(db, id).await? { Some(r) => r, @@ -298,7 +343,7 @@ impl ActiveVm { let mut guest_ports = vec![22]; guest_ports.append(&mut args.exposed_ports.clone()); let mut i = 0; - while i < new_vm_req.extra_ports.len() && i < guest_ports.len() { + while i < args.exposed_ports.len() && i < guest_ports.len() { mapped_ports.push((args.exposed_ports[i], guest_ports[i])); i += 1; } @@ -327,6 +372,66 @@ impl ActiveVm { NewVmReq::delete(db, id).await?; Ok(()) } + + pub async fn update(db: &Surreal, id: &str) -> Result<(), Error> { + let update_vm_req = match UpdateVmReq::get(db, id).await? { + Some(r) => r, + None => return Ok(()), + }; + + let mut active_vm = match Self::get_by_uuid(db, id).await? { + Some(vm) => vm, + None => return Ok(()), + }; + + if update_vm_req.vcpus > 0 { + active_vm.vcpus = update_vm_req.vcpus; + } + if update_vm_req.memory_mb > 0 { + active_vm.memory_mb = update_vm_req.memory_mb; + } + if update_vm_req.disk_size_gb > 0 { + active_vm.disk_size_gb = update_vm_req.disk_size_gb; + } + if !update_vm_req.dtrfs_sha.is_empty() && !update_vm_req.kernel_sha.is_empty() { + active_vm.dtrfs_sha = update_vm_req.dtrfs_sha; + active_vm.kernel_sha = update_vm_req.kernel_sha; + } + + let _: Option = db.update(active_vm.id.clone()).content(active_vm).await?; + UpdateVmReq::delete(db, id).await?; + Ok(()) + } + + pub async fn change_hostname( + db: &Surreal, + id: &str, + new_hostname: &str, + ) -> Result { + let contract: Option = db + .query(format!( + "UPDATE {ACTIVE_VM}:{id} SET hostname = '{new_hostname}' RETURN BEFORE;" + )) + .await? + .take(0)?; + if let Some(contract) = contract { + if contract.hostname != new_hostname { + return Ok(true); + } + } + Ok(false) + } + + pub async fn delete(db: &Surreal, id: &str) -> Result { + let deleted_vm: Option = db.delete((ACTIVE_VM, id)).await?; + if let Some(deleted_vm) = deleted_vm { + let deleted_vm: DeletedVm = deleted_vm.into(); + let _: Vec = db.insert(DELETED_VM).relation(deleted_vm).await?; + Ok(true) + } else { + Ok(false) + } + } } #[derive(Debug, Serialize, Deserialize)] @@ -344,8 +449,94 @@ pub struct UpdateVmReq { pub kernel_sha: String, pub kernel_url: String, pub created_at: Datetime, - pub price_per_unit: u64, - pub locked_nano: u64, + pub error: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct UpdateVmEvent { + pub vm_id: RecordId, + #[serde(rename = "in")] + pub admin: RecordId, + #[serde(rename = "out")] + pub vm_node: RecordId, + pub disk_size_gb: u32, + pub vcpus: u32, + pub memory_mb: u32, + pub dtrfs_url: String, + pub dtrfs_sha: String, + pub kernel_sha: String, + pub kernel_url: String, + pub executed_at: Datetime, +} + +impl From for UpdateVmEvent { + fn from(update_vm_req: UpdateVmReq) -> Self { + Self { + vm_id: RecordId::from((VM_UPDATE_EVENT, update_vm_req.id.key().to_string())), + admin: update_vm_req.admin, + vm_node: update_vm_req.vm_node, + disk_size_gb: update_vm_req.disk_size_gb, + vcpus: update_vm_req.vcpus, + memory_mb: update_vm_req.memory_mb, + dtrfs_url: update_vm_req.dtrfs_url, + dtrfs_sha: update_vm_req.dtrfs_sha, + kernel_sha: update_vm_req.kernel_sha, + kernel_url: update_vm_req.kernel_url, + executed_at: Datetime::default(), + } + } +} + +impl UpdateVmReq { + pub async fn get(db: &Surreal, id: &str) -> Result, Error> { + let update_vm_req: Option = db.select((UPDATE_VM_REQ, id)).await?; + Ok(update_vm_req) + } + + pub async fn delete(db: &Surreal, id: &str) -> Result<(), Error> { + let update_vm_req: Option = db.delete((UPDATE_VM_REQ, id)).await?; + if let Some(update_vm_req) = update_vm_req { + let update_vm_event: UpdateVmEvent = update_vm_req.into(); + let _: Option = + db.create(VM_UPDATE_EVENT).content(update_vm_event).await?; + } + Ok(()) + } + + /// returns None if VM does not exist + /// returns Some(false) if hw update is not needed + /// returns Some(true) if hw update is needed and got submitted + /// returns error if something happened with the DB + pub async fn request_hw_update(mut self, db: &Surreal) -> Result, Error> { + let contract = ActiveVm::get_by_uuid(db, &self.id.key().to_string()).await?; + + if contract.is_none() { + return Ok(None); + } + let contract = contract.unwrap(); + // this is needed cause TryFrom does not support await + self.vm_node = contract.vm_node; + + if !((self.vcpus != 0 && contract.vcpus != self.vcpus) + || (self.memory_mb != 0 && contract.memory_mb != self.memory_mb) + || (!self.dtrfs_sha.is_empty() && contract.dtrfs_sha != self.dtrfs_sha) + || (self.disk_size_gb != 0 && contract.disk_size_gb != self.disk_size_gb)) + { + return Ok(Some(false)); + } + + let _: Vec = db.insert(UPDATE_VM_REQ).relation(self).await?; + Ok(Some(true)) + } + + pub async fn submit_error(db: &Surreal, id: &str, error: String) -> Result<(), Error> { + #[derive(Serialize)] + struct UpdateVmError { + error: String, + } + let _: Option = db.update((UPDATE_VM_REQ, id)).merge(UpdateVmError { error }).await?; + Ok(()) + } } #[derive(Debug, Serialize, Deserialize)] @@ -369,6 +560,28 @@ pub struct DeletedVm { pub price_per_unit: u64, } +impl From for DeletedVm { + fn from(active_vm: ActiveVm) -> Self { + Self { + id: RecordId::from((DELETED_VM, active_vm.id.key().to_string())), + admin: active_vm.admin, + vm_node: active_vm.vm_node, + hostname: active_vm.hostname, + mapped_ports: active_vm.mapped_ports, + public_ipv4: active_vm.public_ipv4, + public_ipv6: active_vm.public_ipv6, + disk_size_gb: active_vm.disk_size_gb, + vcpus: active_vm.vcpus, + memory_mb: active_vm.memory_mb, + dtrfs_sha: active_vm.dtrfs_sha, + kernel_sha: active_vm.kernel_sha, + created_at: active_vm.created_at, + deleted_at: Datetime::default(), + price_per_unit: active_vm.price_per_unit, + } + } +} + impl DeletedVm { pub async fn get_by_uuid(db: &Surreal, uuid: &str) -> Result, Error> { let contract: Option = @@ -431,24 +644,6 @@ impl DeletedVm { } } -impl ActiveVm { - /// total hardware units of this VM - fn total_units(&self) -> u64 { - // TODO: Optimize this based on price of hardware. - // I tried, but this can be done better. - // Storage cost should also be based on tier - (self.vcpus as u64 * 10) - + ((self.memory_mb + 256) as u64 / 200) - + (self.disk_size_gb as u64 / 10) - + (!self.public_ipv4.is_empty() as u64 * 10) - } - - /// Returns price per minute in nanoLP - pub fn price_per_minute(&self) -> u64 { - self.total_units() * self.price_per_unit - } -} - #[derive(Debug, Serialize, Deserialize)] pub struct ActiveVmWithNode { pub id: RecordId, diff --git a/src/grpc/types.rs b/src/grpc/types.rs index 8690908..54cd886 100644 --- a/src/grpc/types.rs +++ b/src/grpc/types.rs @@ -61,15 +61,45 @@ impl From for NewVmReq { } } -impl From for NewVmResp { - fn from(resp: db::NewVmResp) -> Self { +impl From for NewVmResp { + fn from(resp: db::WrappedMeasurement) -> Self { match resp { - // TODO: This will require a small architecture change to pass MeasurementArgs from - // Daemon to CLI - db::NewVmResp::Args(uuid, args) => { - NewVmResp { uuid, error: String::new(), args: Some(args) } + db::WrappedMeasurement::Args(uuid, args) => { + Self { uuid, error: String::new(), args: Some(args) } } - db::NewVmResp::Error(uuid, error) => NewVmResp { uuid, error, args: None }, + db::WrappedMeasurement::Error(uuid, error) => NewVmResp { uuid, error, args: None }, + } + } +} + +// TODO: NewVmResp is identical to UpdateVmResp so we can actually remove it from proto +impl From for UpdateVmResp { + fn from(resp: db::WrappedMeasurement) -> Self { + match resp { + db::WrappedMeasurement::Args(uuid, args) => { + Self { uuid, error: String::new(), args: Some(args) } + } + db::WrappedMeasurement::Error(uuid, error) => Self { uuid, error, args: None }, + } + } +} + +impl From for db::UpdateVmReq { + fn from(new_vm_req: UpdateVmReq) -> Self { + Self { + id: RecordId::from((NEW_VM_REQ, new_vm_req.uuid)), + admin: RecordId::from((ACCOUNT, new_vm_req.admin_pubkey)), + // vm_node gets modified later, and only if the db::UpdateVmReq is required + vm_node: RecordId::from((VM_NODE, String::new())), + disk_size_gb: new_vm_req.disk_size_gb, + vcpus: new_vm_req.vcpus, + memory_mb: new_vm_req.memory_mb, + kernel_url: new_vm_req.kernel_url, + kernel_sha: new_vm_req.kernel_sha, + dtrfs_url: new_vm_req.dtrfs_url, + dtrfs_sha: new_vm_req.dtrfs_sha, + created_at: surrealdb::sql::Datetime::default(), + error: String::new(), } } } diff --git a/src/grpc/vm.rs b/src/grpc/vm.rs index 49f9eff..ebbd218 100644 --- a/src/grpc/vm.rs +++ b/src/grpc/vm.rs @@ -1,5 +1,5 @@ #![allow(dead_code)] -use crate::constants::{ACCOUNT, VM_NODE}; +use crate::constants::{ACCOUNT, NEW_VM_REQ, UPDATE_VM_REQ, VM_NODE}; use crate::db::prelude as db; use crate::grpc::{check_sig_from_parts, check_sig_from_req}; use detee_shared::common_proto::Empty; @@ -148,8 +148,6 @@ impl BrainVmDaemon for VmDaemonServer { match daemon_message { Ok(msg) => match msg.msg { Some(vm_daemon_message::Msg::NewVmResp(new_vm_resp)) => { - // TODO: move new_vm_req to active_vm - // also handle failure properly if !new_vm_resp.error.is_empty() { db::NewVmReq::submit_error( &self.db, @@ -170,9 +168,24 @@ impl BrainVmDaemon for VmDaemonServer { } } } - Some(vm_daemon_message::Msg::UpdateVmResp(_update_vm_resp)) => { - todo!(); - // self.data.submit_updatevm_resp(update_vm_resp).await; + Some(vm_daemon_message::Msg::UpdateVmResp(update_vm_resp)) => { + if !update_vm_resp.error.is_empty() { + db::UpdateVmReq::submit_error( + &self.db, + &update_vm_resp.uuid, + update_vm_resp.error, + ) + .await?; + } else { + db::upsert_record( + &self.db, + "measurement_args", + &update_vm_resp.uuid, + update_vm_resp.args.clone(), + ) + .await?; + db::ActiveVm::update(&self.db, &update_vm_resp.uuid).await?; + } } Some(vm_daemon_message::Msg::VmNodeResources(node_resources)) => { let node_resources: db::VmNodeResources = node_resources.into(); @@ -207,28 +220,29 @@ impl BrainVmCli for VmCliServer { async fn new_vm(&self, req: Request) -> Result, Status> { let req = check_sig_from_req(req)?; info!("New VM requested via CLI: {req:?}"); - if db::general::Account::is_banned_by_node(&self.db, &req.admin_pubkey, &req.node_pubkey).await? { + if db::Account::is_banned_by_node(&self.db, &req.admin_pubkey, &req.node_pubkey).await? { return Err(Status::permission_denied("This operator banned you. What did you do?")); } - let new_vm_req: db::NewVmReq = req.into(); - let id = new_vm_req.id.key().to_string(); - - let db = self.db.clone(); + let db_req: db::NewVmReq = req.into(); + let id = db_req.id.key().to_string(); let (oneshot_tx, oneshot_rx) = tokio::sync::oneshot::channel(); + let db = self.db.clone(); tokio::spawn(async move { - let _ = oneshot_tx.send(db::NewVmResp::listen(&db, &id).await); + let _ = oneshot_tx.send(db::WrappedMeasurement::listen(&db, &id, NEW_VM_REQ).await); }); - new_vm_req.submit(&self.db).await?; + db_req.submit(&self.db).await?; match oneshot_rx.await { - Ok(Err(db::Error::TimeOut(_))) => Err(Status::deadline_exceeded("Request failed due to timeout. Please try again later or contact the DeTEE devs team.")), + Ok(Err(db::Error::TimeOut(_))) => Err(Status::deadline_exceeded( + "Network timeout. Please try again later or contact the DeTEE devs team.", + )), Ok(new_vm_resp) => Ok(Response::new(new_vm_resp?.into())), Err(e) => { - log::error!("Something weird happened. Reached error {e:?}"); + log::error!("Something weird happened during CLI NewVmReq. Reached error {e:?}"); Err(Status::unknown( - "Request failed due to unknown error. Please try again or contact the DeTEE devs team.", + "Unknown error. Please try again or contact the DeTEE devs team.", )) } } @@ -237,18 +251,62 @@ impl BrainVmCli for VmCliServer { async fn update_vm(&self, req: Request) -> Result, Status> { let req = check_sig_from_req(req)?; info!("Update VM requested via CLI: {req:?}"); - todo!(); - // let (oneshot_tx, oneshot_rx) = tokio::sync::oneshot::channel(); - // self.data.submit_updatevm_req(req, oneshot_tx).await; - // match oneshot_rx.await { - // Ok(response) => { - // info!("Sending UpdateVMResp: {response:?}"); - // Ok(Response::new(response)) - // } - // Err(e) => Err(Status::unknown(format!( - // "Update VM request failed due to error: {e}" - // ))), - // } + + let db_req: db::UpdateVmReq = req.clone().into(); + let id = db_req.id.key().to_string(); + + let mut hostname_changed = false; + if !req.hostname.is_empty() { + hostname_changed = + db::ActiveVm::change_hostname(&self.db, &req.uuid, &req.hostname).await?; + } + + let hw_change_submitted = db_req.request_hw_update(&self.db).await?; + if hw_change_submitted.is_none() { + return Ok(Response::new(UpdateVmResp { + uuid: req.uuid.clone(), + error: "VM Contract does not exist.".to_string(), + args: None, + })); + } + let hw_change_needed = hw_change_submitted.unwrap(); + + if !hostname_changed && !hw_change_needed { + return Ok(Response::new(UpdateVmResp { + uuid: req.uuid.clone(), + error: "No modification required".to_string(), + args: None, + })); + } + + // if only the hostname got changed, return a confirmation + if !hw_change_needed { + return Ok(Response::new(UpdateVmResp { + uuid: req.uuid.clone(), + error: String::new(), + args: None, + })); + } + + // if HW changes got requested, wait for the new args + let (oneshot_tx, oneshot_rx) = tokio::sync::oneshot::channel(); + let db = self.db.clone(); + tokio::spawn(async move { + let _ = oneshot_tx.send(db::WrappedMeasurement::listen(&db, &id, UPDATE_VM_REQ).await); + }); + + match oneshot_rx.await { + Ok(Err(db::Error::TimeOut(_))) => Err(Status::deadline_exceeded( + "Network timeout. Please try again later or contact the DeTEE devs team.", + )), + Ok(new_vm_resp) => Ok(Response::new(new_vm_resp?.into())), + Err(e) => { + log::error!("Something weird happened during CLI VM Update. Reached error {e:?}"); + Err(Status::unknown( + "Unknown error. Please try again or contact the DeTEE devs team.", + )) + } + } } async fn extend_vm(&self, req: Request) -> Result, Status> { @@ -264,12 +322,11 @@ impl BrainVmCli for VmCliServer { } async fn delete_vm(&self, req: Request) -> Result, Status> { - let _req = check_sig_from_req(req)?; - todo!(); - // match self.data.delete_vm(req).await { - // Ok(()) => Ok(Response::new(Empty {})), - // Err(e) => Err(Status::not_found(e.to_string())), - // } + let req = check_sig_from_req(req)?; + match db::ActiveVm::delete(&self.db, &req.uuid).await? { + true => Ok(Response::new(Empty {})), + false => Err(Status::not_found(format!("Could not find VM contract {}", &req.uuid))), + } } async fn list_vm_contracts( diff --git a/tests/common/prepare_test_env.rs b/tests/common/prepare_test_env.rs index d3e8b06..85f633e 100644 --- a/tests/common/prepare_test_env.rs +++ b/tests/common/prepare_test_env.rs @@ -5,10 +5,8 @@ use dotenv::dotenv; use hyper_util::rt::TokioIo; use std::net::SocketAddr; use std::sync::Arc; -use surreal_brain::grpc::{ - general::GeneralCliServer, - vm::{VmCliServer, VmDaemonServer}, -}; +use surreal_brain::grpc::general::GeneralCliServer; +use surreal_brain::grpc::vm::{VmCliServer, VmDaemonServer}; use surrealdb::engine::remote::ws::Client; use surrealdb::Surreal; use tokio::io::DuplexStream;