refactor db imports

usage of MeasurementArgs and VmNodeFilters from vm_proto in db module
fix test brain message add new ssh port on mock daemon while new vm
This commit is contained in:
Noor 2025-05-08 14:08:57 +05:30
parent 069bb27192
commit acce76197d
Signed by: noormohammedb
GPG Key ID: D83EFB8B3B967146
5 changed files with 15 additions and 16 deletions

@ -31,7 +31,6 @@ pub mod prelude {
pub use super::general::*; pub use super::general::*;
pub use super::vm::*; pub use super::vm::*;
pub use super::*; pub use super::*;
pub use detee_shared::snp::pb::vm_proto::{MeasurementArgs, VmNodeFilters};
} }
pub async fn db_connection( pub async fn db_connection(

@ -5,8 +5,9 @@ use super::Error;
use crate::constants::{ use crate::constants::{
ACCOUNT, ACTIVE_VM, DELETED_VM, NEW_VM_REQ, UPDATE_VM_REQ, VM_NODE, VM_UPDATE_EVENT, ACCOUNT, ACTIVE_VM, DELETED_VM, NEW_VM_REQ, UPDATE_VM_REQ, VM_NODE, VM_UPDATE_EVENT,
}; };
use crate::db::{Account, MeasurementArgs, Report, VmNodeFilters}; use crate::db::{Account, Report};
use crate::old_brain; use crate::old_brain;
use detee_shared::vm_proto;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use surrealdb::engine::remote::ws::Client; use surrealdb::engine::remote::ws::Client;
use surrealdb::sql::Datetime; use surrealdb::sql::Datetime;
@ -83,7 +84,7 @@ impl VmNodeWithReports {
// https://en.wikipedia.org/wiki/Dependency_inversion_principle // https://en.wikipedia.org/wiki/Dependency_inversion_principle
pub async fn find_by_filters( pub async fn find_by_filters(
db: &Surreal<Client>, db: &Surreal<Client>,
filters: VmNodeFilters, filters: vm_proto::VmNodeFilters,
) -> Result<Vec<Self>, Error> { ) -> Result<Vec<Self>, Error> {
let mut query = format!( let mut query = format!(
"select *, <-report.* as reports from {VM_NODE} where "select *, <-report.* as reports from {VM_NODE} where
@ -197,7 +198,7 @@ impl NewVmReq {
/// first string is the vm_id /// first string is the vm_id
pub enum WrappedMeasurement { pub enum WrappedMeasurement {
Args(String, MeasurementArgs), Args(String, vm_proto::MeasurementArgs),
Error(String, String), Error(String, String),
} }
@ -224,9 +225,10 @@ impl WrappedMeasurement {
)) ))
.await?; .await?;
let mut error_stream = resp.stream::<Notification<ErrorMessage>>(0)?; let mut error_stream = resp.stream::<Notification<ErrorMessage>>(0)?;
let mut args_stream = resp.stream::<Notification<MeasurementArgs>>(1)?; let mut args_stream = resp.stream::<Notification<vm_proto::MeasurementArgs>>(1)?;
let args: Option<MeasurementArgs> = db.delete(("measurement_args", vm_id)).await?; let args: Option<vm_proto::MeasurementArgs> =
db.delete(("measurement_args", vm_id)).await?;
if let Some(args) = args { if let Some(args) = args {
return Ok(Self::Args(vm_id.to_string(), args)); return Ok(Self::Args(vm_id.to_string(), args));
} }
@ -255,7 +257,7 @@ impl WrappedMeasurement {
match args_notif { match args_notif {
Ok(args_notif) => { Ok(args_notif) => {
if args_notif.action == surrealdb::Action::Create { if args_notif.action == surrealdb::Action::Create {
let _: Option<MeasurementArgs> = db.delete(("measurement_args", vm_id)).await?; let _: Option<vm_proto::MeasurementArgs> = db.delete(("measurement_args", vm_id)).await?;
return Ok(Self::Args(vm_id.to_string(), args_notif.data)); return Ok(Self::Args(vm_id.to_string(), args_notif.data));
}; };
}, },
@ -318,7 +320,7 @@ impl ActiveVm {
pub async fn activate( pub async fn activate(
db: &Surreal<Client>, db: &Surreal<Client>,
id: &str, id: &str,
args: MeasurementArgs, args: vm_proto::MeasurementArgs,
) -> Result<(), Error> { ) -> Result<(), Error> {
let new_vm_req = match NewVmReq::get(db, id).await? { let new_vm_req = match NewVmReq::get(db, id).await? {
Some(r) => r, Some(r) => r,

@ -107,9 +107,10 @@ pub async fn daemon_engine(
while let Some(brain_msg) = rx.recv().await { while let Some(brain_msg) = rx.recv().await {
match brain_msg.msg { match brain_msg.msg {
Some(vm_proto::brain_vm_message::Msg::NewVmReq(new_vm_req)) => { Some(vm_proto::brain_vm_message::Msg::NewVmReq(new_vm_req)) => {
let exposed_ports = [vec![22], new_vm_req.extra_ports].concat();
let args = Some(vm_proto::MeasurementArgs { let args = Some(vm_proto::MeasurementArgs {
dtrfs_api_endpoint: String::from("184.107.169.199:48865"), dtrfs_api_endpoint: String::from("184.107.169.199:48865"),
exposed_ports: new_vm_req.extra_ports, exposed_ports,
ovmf_hash: String::from( ovmf_hash: String::from(
"0346619257269b9a61ee003e197d521b8e2283483070d163a34940d6a1d40d76", "0346619257269b9a61ee003e197d521b8e2283483070d163a34940d6a1d40d76",
), ),

@ -51,7 +51,7 @@ async fn test_vm_creation_timeout() {
assert_eq!( assert_eq!(
timeout_error.message(), timeout_error.message(),
"Request failed due to timeout. Please try again later or contact the DeTEE devs team." "Network timeout. Please try again later or contact the DeTEE devs team.",
) )
} }

@ -25,7 +25,7 @@ async fn test_reg_vm_node() {
#[tokio::test] #[tokio::test]
async fn test_brain_message() { async fn test_brain_message() {
env_logger::builder().filter_level(log::LevelFilter::Error).init(); env_logger::builder().filter_level(log::LevelFilter::Error).init();
let db = prepare_test_db().await.unwrap(); prepare_test_db().await.unwrap();
let brain_channel = run_service_for_stream().await.unwrap(); let brain_channel = run_service_for_stream().await.unwrap();
let daemon_key = mock_vm_daemon(&brain_channel).await.unwrap(); let daemon_key = mock_vm_daemon(&brain_channel).await.unwrap();
@ -46,9 +46,6 @@ async fn test_brain_message() {
assert!(new_vm_resp.error.is_empty()); assert!(new_vm_resp.error.is_empty());
assert!(new_vm_resp.uuid.len() == 40); assert!(new_vm_resp.uuid.len() == 40);
assert!(new_vm_resp.args.is_some());
let id = ("measurement_args", new_vm_resp.uuid); assert!(new_vm_resp.args.unwrap().exposed_ports.len() == 3);
let data_in_db: detee_shared::vm_proto::MeasurementArgs = db.select(id).await.unwrap().unwrap();
assert_eq!(data_in_db, new_vm_resp.args.unwrap());
} }