diff --git a/src/db/mod.rs b/src/db/mod.rs index f6968c4..6b3f6c4 100644 --- a/src/db/mod.rs +++ b/src/db/mod.rs @@ -31,7 +31,6 @@ pub mod prelude { pub use super::general::*; pub use super::vm::*; pub use super::*; - pub use detee_shared::snp::pb::vm_proto::{MeasurementArgs, VmNodeFilters}; } pub async fn db_connection( diff --git a/src/db/vm.rs b/src/db/vm.rs index 2da35e4..5641dea 100644 --- a/src/db/vm.rs +++ b/src/db/vm.rs @@ -5,8 +5,9 @@ use super::Error; use crate::constants::{ ACCOUNT, ACTIVE_VM, DELETED_VM, NEW_VM_REQ, UPDATE_VM_REQ, VM_NODE, VM_UPDATE_EVENT, }; -use crate::db::{Account, MeasurementArgs, Report, VmNodeFilters}; +use crate::db::{Account, Report}; use crate::old_brain; +use detee_shared::vm_proto; use serde::{Deserialize, Serialize}; use surrealdb::engine::remote::ws::Client; use surrealdb::sql::Datetime; @@ -83,7 +84,7 @@ impl VmNodeWithReports { // https://en.wikipedia.org/wiki/Dependency_inversion_principle pub async fn find_by_filters( db: &Surreal, - filters: VmNodeFilters, + filters: vm_proto::VmNodeFilters, ) -> Result, Error> { let mut query = format!( "select *, <-report.* as reports from {VM_NODE} where @@ -197,7 +198,7 @@ impl NewVmReq { /// first string is the vm_id pub enum WrappedMeasurement { - Args(String, MeasurementArgs), + Args(String, vm_proto::MeasurementArgs), Error(String, String), } @@ -224,9 +225,10 @@ impl WrappedMeasurement { )) .await?; let mut error_stream = resp.stream::>(0)?; - let mut args_stream = resp.stream::>(1)?; + let mut args_stream = resp.stream::>(1)?; - let args: Option = db.delete(("measurement_args", vm_id)).await?; + let args: Option = + db.delete(("measurement_args", vm_id)).await?; if let Some(args) = args { return Ok(Self::Args(vm_id.to_string(), args)); } @@ -255,7 +257,7 @@ impl WrappedMeasurement { match args_notif { Ok(args_notif) => { if args_notif.action == surrealdb::Action::Create { - let _: Option = db.delete(("measurement_args", vm_id)).await?; + let _: Option = db.delete(("measurement_args", vm_id)).await?; return Ok(Self::Args(vm_id.to_string(), args_notif.data)); }; }, @@ -318,7 +320,7 @@ impl ActiveVm { pub async fn activate( db: &Surreal, id: &str, - args: MeasurementArgs, + args: vm_proto::MeasurementArgs, ) -> Result<(), Error> { let new_vm_req = match NewVmReq::get(db, id).await? { Some(r) => r, diff --git a/tests/common/vm_daemon_utils.rs b/tests/common/vm_daemon_utils.rs index 818e47d..58a9446 100644 --- a/tests/common/vm_daemon_utils.rs +++ b/tests/common/vm_daemon_utils.rs @@ -107,9 +107,10 @@ pub async fn daemon_engine( while let Some(brain_msg) = rx.recv().await { match brain_msg.msg { Some(vm_proto::brain_vm_message::Msg::NewVmReq(new_vm_req)) => { + let exposed_ports = [vec![22], new_vm_req.extra_ports].concat(); let args = Some(vm_proto::MeasurementArgs { dtrfs_api_endpoint: String::from("184.107.169.199:48865"), - exposed_ports: new_vm_req.extra_ports, + exposed_ports, ovmf_hash: String::from( "0346619257269b9a61ee003e197d521b8e2283483070d163a34940d6a1d40d76", ), diff --git a/tests/grpc_vm_cli_test.rs b/tests/grpc_vm_cli_test.rs index d66be88..e1e1df4 100644 --- a/tests/grpc_vm_cli_test.rs +++ b/tests/grpc_vm_cli_test.rs @@ -51,7 +51,7 @@ async fn test_vm_creation_timeout() { assert_eq!( timeout_error.message(), - "Request failed due to timeout. Please try again later or contact the DeTEE devs team." + "Network timeout. Please try again later or contact the DeTEE devs team.", ) } diff --git a/tests/grpc_vm_daemon_test.rs b/tests/grpc_vm_daemon_test.rs index ef9264a..89d7983 100644 --- a/tests/grpc_vm_daemon_test.rs +++ b/tests/grpc_vm_daemon_test.rs @@ -25,7 +25,7 @@ async fn test_reg_vm_node() { #[tokio::test] async fn test_brain_message() { env_logger::builder().filter_level(log::LevelFilter::Error).init(); - let db = prepare_test_db().await.unwrap(); + prepare_test_db().await.unwrap(); let brain_channel = run_service_for_stream().await.unwrap(); let daemon_key = mock_vm_daemon(&brain_channel).await.unwrap(); @@ -46,9 +46,6 @@ async fn test_brain_message() { assert!(new_vm_resp.error.is_empty()); assert!(new_vm_resp.uuid.len() == 40); - - let id = ("measurement_args", new_vm_resp.uuid); - let data_in_db: detee_shared::vm_proto::MeasurementArgs = db.select(id).await.unwrap().unwrap(); - - assert_eq!(data_in_db, new_vm_resp.args.unwrap()); + assert!(new_vm_resp.args.is_some()); + assert!(new_vm_resp.args.unwrap().exposed_ports.len() == 3); }