diff --git a/src/Enclave.edl b/src/Enclave.edl
index 48f81a85..ca08ec0e 100644
--- a/src/Enclave.edl
+++ b/src/Enclave.edl
@@ -107,6 +107,8 @@ enclave {
         void* occlum_ocall_posix_memalign(size_t alignment, size_t size);
         void occlum_ocall_free([user_check] void* ptr);
 
+        int occlum_ocall_mprotect([user_check] void* addr, size_t len, int prot);
+
         void occlum_ocall_sched_yield(void);
         int occlum_ocall_sched_setaffinity(
             int host_tid,
diff --git a/src/libos/src/syscall/mod.rs b/src/libos/src/syscall/mod.rs
index 4bfa5a28..5c9f075b 100644
--- a/src/libos/src/syscall/mod.rs
+++ b/src/libos/src/syscall/mod.rs
@@ -745,8 +745,9 @@ fn do_mremap(
     Ok(addr as isize)
 }
 
-fn do_mprotect(addr: usize, len: usize, prot: u32) -> Result<isize> {
-    // TODO: implement it
+fn do_mprotect(addr: usize, len: usize, perms: u32) -> Result<isize> {
+    let perms = VMPerms::from_u32(perms)?;
+    vm::do_mprotect(addr, len, perms)?;
     Ok(0)
 }
diff --git a/src/libos/src/vm/mod.rs b/src/libos/src/vm/mod.rs
index 7c1d8968..f3d80337 100644
--- a/src/libos/src/vm/mod.rs
+++ b/src/libos/src/vm/mod.rs
@@ -5,15 +5,18 @@ use std::fmt;
 
 mod process_vm;
 mod user_space_vm;
+mod vm_area;
 mod vm_layout;
 mod vm_manager;
+mod vm_perms;
 mod vm_range;
 
 use self::vm_layout::VMLayout;
 use self::vm_manager::{VMManager, VMMapOptionsBuilder};
 
-pub use self::process_vm::{MMapFlags, MRemapFlags, ProcessVM, ProcessVMBuilder, VMPerms};
+pub use self::process_vm::{MMapFlags, MRemapFlags, ProcessVM, ProcessVMBuilder};
 pub use self::user_space_vm::USER_SPACE_VM_MANAGER;
+pub use self::vm_perms::VMPerms;
 pub use self::vm_range::VMRange;
 
 pub fn do_mmap(
@@ -63,6 +66,16 @@ pub fn do_mremap(
     current_vm.mremap(old_addr, old_size, new_size, flags)
}
 
+pub fn do_mprotect(addr: usize, size: usize, perms: VMPerms) -> Result<()> {
+    debug!(
+        "mprotect: addr: {:#x}, size: {:#x}, perms: {:?}",
+        addr, size, perms
+    );
+    let current = current!();
+    let mut current_vm = current.vm().lock().unwrap();
+    current_vm.mprotect(addr, size, perms)
+}
+
 pub fn do_brk(addr: usize) -> Result<usize> {
     debug!("brk: addr: {:#x}", addr);
     let current = current!();
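Note that the syscall handler above funnels the raw `prot` bits through `VMPerms::from_u32` before anything touches the VM manager, so unknown bits fail fast with EINVAL. A unit-test-style sketch of that contract, using the `VMPerms` type introduced later in this patch (illustrative only):

```rust
#[test]
fn prot_bits_are_validated_up_front() {
    // PROT_READ | PROT_WRITE | PROT_EXEC == 0x7: every bit is known, so this parses.
    assert!(VMPerms::from_u32(0x7).is_ok());
    // 0x1234 carries bits outside READ/WRITE/EXEC, so from_bits() yields None and
    // the syscall returns EINVAL (exercised by test_mprotect_with_invalid_prot below).
    assert!(VMPerms::from_u32(0x1234).is_err());
}
```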
diff --git a/src/libos/src/vm/process_vm.rs b/src/libos/src/vm/process_vm.rs
index 3a375709..19840595 100644
--- a/src/libos/src/vm/process_vm.rs
+++ b/src/libos/src/vm/process_vm.rs
@@ -6,6 +6,7 @@ use super::user_space_vm::{UserSpaceVMManager, UserSpaceVMRange, USER_SPACE_VM_M
 use super::vm_manager::{
     VMInitializer, VMManager, VMMapAddr, VMMapOptions, VMMapOptionsBuilder, VMRemapOptions,
 };
+use super::vm_perms::VMPerms;
 
 #[derive(Debug)]
 pub struct ProcessVMBuilder<'a, 'b> {
@@ -202,12 +203,18 @@ impl<'a, 'b> ProcessVMBuilder<'a, 'b> {
 /// The per-process virtual memory
 #[derive(Debug)]
 pub struct ProcessVM {
-    process_range: UserSpaceVMRange,
+    mmap_manager: VMManager,
     elf_ranges: Vec<VMRange>,
     heap_range: VMRange,
     stack_range: VMRange,
     brk: usize,
-    mmap_manager: VMManager,
+    // Memory safety notes: the process_range field must be the last one.
+    //
+    // Rust drops fields in the same order as they are declared. So by making
+    // process_range the last field, we ensure that when all other fields are
+    // dropped, their drop methods (if provided) can still access the memory
+    // region represented by the process_range field.
+    process_range: UserSpaceVMRange,
 }
 
 impl Default for ProcessVM {
@@ -313,6 +320,7 @@ impl ProcessVM {
         let mmap_options = VMMapOptionsBuilder::default()
             .size(size)
             .addr(addr_option)
+            .perms(perms)
             .initializer(initializer)
             .build()?;
         let mmap_addr = self.mmap_manager.mmap(&mmap_options)?;
@@ -340,6 +348,20 @@
         self.mmap_manager.munmap(addr, size)
     }
 
+    pub fn mprotect(&mut self, addr: usize, size: usize, perms: VMPerms) -> Result<()> {
+        let protect_range = VMRange::new_with_size(addr, size)?;
+        if !self.process_range.range().is_superset_of(&protect_range) {
+            return_errno!(ENOMEM, "invalid range");
+        }
+        // TODO: support mprotect on memory regions other than the mmap region
+        if !self.mmap_manager.range().is_superset_of(&protect_range) {
+            warn!("mprotect on memory outside the mmap region is not supported yet");
+            return Ok(());
+        }
+
+        self.mmap_manager.mprotect(addr, size, perms)
+    }
+
     pub fn find_mmap_region(&self, addr: usize) -> Result<&VMRange> {
         self.mmap_manager.find_mmap_region(addr)
     }
@@ -412,32 +434,6 @@ impl Default for MRemapFlags {
 }
 
-bitflags! {
-    pub struct VMPerms : u32 {
-        const READ = 0x1;
-        const WRITE = 0x2;
-        const EXEC = 0x4;
-    }
-}
-
-impl VMPerms {
-    pub fn can_read(&self) -> bool {
-        self.contains(VMPerms::READ)
-    }
-
-    pub fn can_write(&self) -> bool {
-        self.contains(VMPerms::WRITE)
-    }
-
-    pub fn can_execute(&self) -> bool {
-        self.contains(VMPerms::EXEC)
-    }
-
-    pub fn from_u32(bits: u32) -> Result<VMPerms> {
-        VMPerms::from_bits(bits).ok_or_else(|| errno!(EINVAL, "unknown permission bits"))
-    }
-}
-
 unsafe fn fill_zeros(addr: usize, size: usize) {
     let ptr = addr as *mut u8;
     let buf = std::slice::from_raw_parts_mut(ptr, size);
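The memory-safety note on `process_range` leans on a language guarantee that is easy to forget: Rust drops struct fields in declaration order. A self-contained sketch (the `Logger` and `Vm` types are hypothetical, not part of this patch) that demonstrates the ordering:

```rust
struct Logger(&'static str);

impl Drop for Logger {
    fn drop(&mut self) {
        println!("dropping {}", self.0);
    }
}

// Analogous to ProcessVM: `manager` plays the role of mmap_manager and
// `backing_region` the role of process_range.
struct Vm {
    manager: Logger,
    backing_region: Logger,
}

fn main() {
    let _vm = Vm {
        manager: Logger("manager"),
        backing_region: Logger("backing_region"),
    };
    // Prints "dropping manager" first, then "dropping backing_region":
    // the manager can still rely on the backing region while it is dropped.
}
```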
diff --git a/src/libos/src/vm/vm_area.rs b/src/libos/src/vm/vm_area.rs
new file mode 100644
index 00000000..6ae22c5a
--- /dev/null
+++ b/src/libos/src/vm/vm_area.rs
@@ -0,0 +1,51 @@
+use std::ops::{Deref, DerefMut};
+
+use super::vm_perms::VMPerms;
+use super::vm_range::VMRange;
+use super::*;
+
+#[derive(Clone, Copy, Debug, Default, PartialEq)]
+pub struct VMArea {
+    range: VMRange,
+    perms: VMPerms,
+}
+
+impl VMArea {
+    pub fn new(range: VMRange, perms: VMPerms) -> Self {
+        Self { range, perms }
+    }
+
+    pub fn perms(&self) -> VMPerms {
+        self.perms
+    }
+
+    pub fn range(&self) -> &VMRange {
+        &self.range
+    }
+
+    pub fn set_perms(&mut self, new_perms: VMPerms) {
+        self.perms = new_perms;
+    }
+
+    pub fn subtract(&self, other: &VMRange) -> Vec<VMArea> {
+        self.deref()
+            .subtract(other)
+            .iter()
+            .map(|range| VMArea::new(*range, self.perms()))
+            .collect()
+    }
+}
+
+impl Deref for VMArea {
+    type Target = VMRange;
+
+    fn deref(&self) -> &Self::Target {
+        &self.range
+    }
+}
+
+impl DerefMut for VMArea {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        &mut self.range
+    }
+}
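`VMArea` is deliberately thin: `Deref` lets all the `VMRange` queries (`start`, `end`, `intersect`, ...) work on a VMA unchanged, while `subtract` is re-wrapped so the returned pieces keep the permissions of the original area. A unit-test-style sketch of the behavior, assuming page-aligned addresses and the APIs from this patch (unwraps for brevity):

```rust
#[test]
fn subtract_preserves_perms() {
    // Punching a hole in a read-only area yields two read-only areas,
    // one on each side of the hole.
    let area = VMArea::new(
        VMRange::new(0x10_0000, 0x10_4000).unwrap(),
        VMPerms::READ,
    );
    let hole = VMRange::new(0x10_1000, 0x10_2000).unwrap();

    let parts = area.subtract(&hole);
    // parts covers [0x10_0000, 0x10_1000) and [0x10_2000, 0x10_4000)
    assert_eq!(parts.len(), 2);
    assert!(parts.iter().all(|vma| vma.perms() == VMPerms::READ));
}
```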
diff --git a/src/libos/src/vm/vm_manager.rs b/src/libos/src/vm/vm_manager.rs
index 024b9fb1..a0b7f642 100644
--- a/src/libos/src/vm/vm_manager.rs
+++ b/src/libos/src/vm/vm_manager.rs
@@ -1,5 +1,8 @@
 use super::*;
 
+use super::vm_area::VMArea;
+use super::vm_perms::VMPerms;
+
 #[derive(Clone, Debug)]
 pub enum VMInitializer {
     DoNothing(),
@@ -66,6 +69,7 @@ impl Default for VMMapAddr {
 pub struct VMMapOptions {
     size: usize,
     align: usize,
+    perms: VMPerms,
     addr: VMMapAddr,
     initializer: VMInitializer,
 }
@@ -89,6 +93,9 @@ impl VMMapOptionsBuilder {
             }
             align
         };
+        let perms = self
+            .perms
+            .ok_or_else(|| errno!(EINVAL, "perms must be given"))?;
         let addr = {
             let addr = self.addr.unwrap_or_default();
             match addr {
@@ -113,6 +120,7 @@
         Ok(VMMapOptions {
             size,
             align,
+            perms,
             addr,
             initializer,
         })
     }
@@ -128,6 +136,10 @@
         &self.addr
     }
 
+    pub fn perms(&self) -> &VMPerms {
+        &self.perms
+    }
+
     pub fn initializer(&self) -> &VMInitializer {
         &self.initializer
     }
@@ -201,52 +213,61 @@ impl VMRemapOptions {
 ///
 /// # Invariants
 ///
-/// Behind the scene, VMManager maintains a list of VMRange that have been allocated.
-/// (denoted as `self.sub_ranges`). To reason about the correctness of VMManager, we give
+/// Behind the scenes, VMManager maintains a list of the VMAreas that have been allocated
+/// (denoted as `self.vmas`). To reason about the correctness of VMManager, we give
 /// the set of invariants held by VMManager.
 ///
 /// 1. The rule of sentry:
 /// ```
-/// self.range.begin() == self.sub_ranges[0].start() == self.sub_ranges[0].end()
+/// self.range.start() == self.vmas[0].start() == self.vmas[0].end()
 /// ```
 /// and
 /// ```
-/// self.range.end() == self.sub_ranges[N-1].start() == self.sub_ranges[N-1].end()
+/// self.range.end() == self.vmas[N-1].start() == self.vmas[N-1].end()
 /// ```
-/// where `N = self.sub_ranges.len()`.
+/// where `N = self.vmas.len()`.
 ///
 /// 2. The rule of non-emptiness:
 /// ```
-/// self.sub_ranges[i].size() > 0, for 1 <= i < self.sub_ranges.len() - 1
+/// self.vmas[i].size() > 0, for 1 <= i < self.vmas.len() - 1
 /// ```
 ///
 /// 3. The rule of ordering:
 /// ```
-/// self.sub_ranges[i].end() <= self.sub_ranges[i+1].start() for 0 <= i < self.sub_ranges.len() - 1
+/// self.vmas[i].end() <= self.vmas[i+1].start() for 0 <= i < self.vmas.len() - 1
 /// ```
 ///
 /// 4. The rule of non-mergeability:
 /// ```
-/// self.sub_ranges[i].end() != self.sub_ranges[i+1].start() for 1 <= i < self.sub_ranges.len() - 2
+/// self.vmas[i].end() != self.vmas[i+1].start() || self.vmas[i].perms() != self.vmas[i+1].perms()
+/// for 1 <= i < self.vmas.len() - 2
 /// ```
 ///
 #[derive(Debug, Default)]
 pub struct VMManager {
     range: VMRange,
-    sub_ranges: Vec<VMRange>,
+    vmas: Vec<VMArea>,
 }
 
 impl VMManager {
     pub fn from(addr: usize, size: usize) -> Result<Self> {
         let range = VMRange::new(addr, addr + size)?;
-        let sub_ranges = {
+        let vmas = {
             let start = range.start();
             let end = range.end();
-            let start_sentry = VMRange::new(start, start)?;
-            let end_sentry = VMRange::new(end, end)?;
+            let start_sentry = {
+                let range = VMRange::new_empty(start)?;
+                let perms = VMPerms::empty();
+                VMArea::new(range, perms)
+            };
+            let end_sentry = {
+                let range = VMRange::new_empty(end)?;
+                let perms = VMPerms::empty();
+                VMArea::new(range, perms)
+            };
             vec![start_sentry, end_sentry]
         };
-        Ok(VMManager { range, sub_ranges })
+        Ok(VMManager { range, vmas })
     }
 
     pub fn range(&self) -> &VMRange {
@@ -262,21 +283,24 @@
             self.munmap(addr, size)?;
         }
 
-        // Allocate a new subrange for this mmap request
-        let (insert_idx, free_subrange) = self.find_free_subrange(size, addr)?;
-        let new_subrange = self.alloc_subrange_from(size, addr, &free_subrange);
-        let new_subrange_addr = new_subrange.start();
+        // Allocate a new range for this mmap request
+        let (insert_idx, free_range) = self.find_free_range(size, addr)?;
+        let new_range = self.alloc_range_from(size, addr, &free_range);
+        let new_addr = new_range.start();
+        let new_vma = VMArea::new(new_range, *options.perms());
 
-        // Initialize the memory of the new subrange
+        // Initialize the memory of the new range
         unsafe {
-            let buf = new_subrange.as_slice_mut();
+            let buf = new_vma.as_slice_mut();
             options.initializer.init_slice(buf)?;
         }
+        // Set memory permissions
+        Self::apply_perms(&new_vma, new_vma.perms());
 
-        // After initializing, we can safely insert the new subrange
-        self.insert_new_subrange(insert_idx, new_subrange);
+        // After initializing, we can safely insert the new VMA
+        self.insert_new_vma(insert_idx, new_vma);
 
-        Ok(new_subrange_addr)
+        Ok(new_addr)
     }
 
     pub fn munmap(&mut self, addr: usize, size: usize) -> Result<()> {
@@ -301,24 +325,27 @@
             effective_munmap_range
         };
 
-        let new_sub_ranges = self
-            .sub_ranges
+        let new_vmas = self
+            .vmas
             .iter()
-            .flat_map(|subrange| {
-                // Keep the two sentry subranges intact
-                if subrange.size() == 0 {
-                    return vec![*subrange];
+            .flat_map(|vma| {
+                // Keep the two sentry VMAs intact
+                if vma.size() == 0 {
+                    return vec![*vma];
                 }
 
-                let unmapped_subrange = match subrange.intersect(&munmap_range) {
-                    None => return vec![*subrange],
-                    Some(unmapped_subrange) => unmapped_subrange,
+                let intersection_range = match vma.intersect(&munmap_range) {
+                    None => return vec![*vma],
+                    Some(intersection_range) => intersection_range,
                 };
 
-                subrange.subtract(&unmapped_subrange)
+                // Reset memory permissions
+                Self::apply_perms(&intersection_range, VMPerms::default());
+
+                vma.subtract(&intersection_range)
             })
            .collect();
-        self.sub_ranges = new_sub_ranges;
+        self.vmas = new_vmas;
         Ok(())
     }
@@ -343,9 +370,15 @@
             SizeType::Growing
         };
 
-        // The old range must not span over multiple sub-ranges
-        self.find_containing_subrange_idx(&old_range)
-            .ok_or_else(|| errno!(EFAULT, "invalid range"))?;
+        // Get the memory permissions of the old range
+        let perms = {
+            // The old range must be contained in one VMA
+            let idx = self
+                .find_containing_vma_idx(&old_range)
+                .ok_or_else(|| errno!(EFAULT, "invalid range"))?;
+            let containing_vma = &self.vmas[idx];
+            containing_vma.perms()
+        };
 
         // Implement mremap as one optional mmap followed by one optional munmap.
         //
@@ -362,6 +395,7 @@
                 let mmap_opts = VMMapOptionsBuilder::default()
                     .size(new_size - old_size)
                     .addr(VMMapAddr::Need(old_range.end()))
+                    .perms(perms)
                     .initializer(VMInitializer::FillZeros())
                     .build()?;
                 let ret_addr = Some(old_addr);
@@ -374,6 +408,7 @@
                     let mmap_ops = VMMapOptionsBuilder::default()
                         .size(prefered_new_range.size())
                         .addr(VMMapAddr::Need(prefered_new_range.start()))
+                        .perms(perms)
                         .initializer(VMInitializer::FillZeros())
                         .build()?;
                     (Some(mmap_ops), Some(old_addr))
@@ -381,6 +416,7 @@
                     let mmap_ops = VMMapOptionsBuilder::default()
                         .size(new_size)
                         .addr(VMMapAddr::Any)
+                        .perms(perms)
                        .initializer(VMInitializer::CopyFrom { range: old_range })
                         .build()?;
                     // Cannot determine the returned address for now, which can only be obtained after calling mmap
@@ -392,6 +428,7 @@
                 let mmap_opts = VMMapOptionsBuilder::default()
                     .size(new_size)
                     .addr(VMMapAddr::Force(new_addr))
+                    .perms(perms)
                     .initializer(VMInitializer::CopyFrom { range: old_range })
                     .build()?;
                 let ret_addr = Some(new_addr);
@@ -442,41 +479,109 @@
         Ok(ret_addr.unwrap())
     }
 
+    pub fn mprotect(&mut self, addr: usize, size: usize, new_perms: VMPerms) -> Result<()> {
+        let protect_range = VMRange::new_with_size(addr, size)?;
+
+        // FIXME: the current implementation requires the target range to be
+        // contained in exactly one VMA.
+        let containing_idx = self
+            .find_containing_vma_idx(&protect_range)
+            .ok_or_else(|| errno!(ENOMEM, "invalid range"))?;
+        let containing_vma = &self.vmas[containing_idx];
+
+        let old_perms = containing_vma.perms();
+        if new_perms == old_perms {
+            return Ok(());
+        }
+
+        let same_start = protect_range.start() == containing_vma.start();
+        let same_end = protect_range.end() == containing_vma.end();
+        let containing_vma = &mut self.vmas[containing_idx];
+        match (same_start, same_end) {
+            (true, true) => {
+                containing_vma.set_perms(new_perms);
+
+                Self::apply_perms(containing_vma, containing_vma.perms());
+            }
+            (false, true) => {
+                containing_vma.set_end(protect_range.start());
+
+                let new_vma = VMArea::new(protect_range, new_perms);
+                Self::apply_perms(&new_vma, new_vma.perms());
+                self.insert_new_vma(containing_idx + 1, new_vma);
+            }
+            (true, false) => {
+                containing_vma.set_start(protect_range.end());
+
+                let new_vma = VMArea::new(protect_range, new_perms);
+                Self::apply_perms(&new_vma, new_vma.perms());
+                self.insert_new_vma(containing_idx, new_vma);
+            }
+            (false, false) => {
+                // The containing VMA is divided into three VMAs:
+                // Shrunk old VMA:  [containing_vma.start, protect_range.start)
+                // New VMA:         [protect_range.start, protect_range.end)
+                // Another new VMA: [protect_range.end, containing_vma.end)
+
+                let old_end = containing_vma.end();
+                let protect_end = protect_range.end();
+
+                // Shrunk old VMA
+                containing_vma.set_end(protect_range.start());
+
+                // New VMA
+                let new_vma = VMArea::new(protect_range, new_perms);
+                Self::apply_perms(&new_vma, new_vma.perms());
+                self.insert_new_vma(containing_idx + 1, new_vma);
+
+                // Another new VMA
+                let new_vma2 = {
+                    let range = VMRange::new(protect_end, old_end).unwrap();
+                    VMArea::new(range, old_perms)
+                };
+                self.insert_new_vma(containing_idx + 2, new_vma2);
+            }
+        }
+
+        Ok(())
+    }
+
     pub fn find_mmap_region(&self, addr: usize) -> Result<&VMRange> {
-        self.sub_ranges
+        self.vmas
             .iter()
-            .find(|subrange| subrange.contains(addr))
+            .map(|vma| vma.range())
+            .find(|range| range.contains(addr))
             .ok_or_else(|| errno!(ESRCH, "no mmap regions that contains the address"))
     }
 
-    // Find a subrange that contains the given range and returns the index of the subrange
-    fn find_containing_subrange_idx(&self, target_range: &VMRange) -> Option<usize> {
-        self.sub_ranges
+    // Find a VMA that contains the given range, returning the VMA's index
+    fn find_containing_vma_idx(&self, target_range: &VMRange) -> Option<usize> {
+        self.vmas
             .iter()
-            .position(|subrange| subrange.is_superset_of(target_range))
+            .position(|vma| vma.is_superset_of(target_range))
     }
 
     // Returns whether the requested range is free
     fn is_free_range(&self, request_range: &VMRange) -> bool {
         self.range.is_superset_of(request_range)
             && self
-                .sub_ranges
+                .vmas
                 .iter()
                 .all(|range| range.overlap_with(request_range) == false)
     }
 
-    // Find the free subrange that satisfies the constraints of size and address
-    fn find_free_subrange(&self, size: usize, addr: VMMapAddr) -> Result<(usize, VMRange)> {
+    // Find the free range that satisfies the constraints of size and address
+    fn find_free_range(&self, size: usize, addr: VMMapAddr) -> Result<(usize, VMRange)> {
         // TODO: reduce the complexity from O(N) to O(log(N)), where N is
-        // the number of existing subranges.
+        // the number of existing VMAs.
 
         // Record the minimal free range that satisfies the constraints
         let mut result_free_range: Option<VMRange> = None;
         let mut result_idx: Option<usize> = None;
 
-        for (idx, range_pair) in self.sub_ranges.windows(2).enumerate() {
-            // Since we have two sentry sub_ranges at both ends, we can be sure that the free
-            // space only appears between two consecutive sub_ranges.
+        for (idx, range_pair) in self.vmas.windows(2).enumerate() {
+            // Since we have two sentry VMAs at both ends, we can be sure that the free
+            // space only appears between two consecutive VMAs.
             let pre_range = &range_pair[0];
             let next_range = &range_pair[1];
@@ -539,68 +644,101 @@
         Ok((insert_idx, free_range))
     }
 
-    fn alloc_subrange_from(
-        &self,
-        size: usize,
-        addr: VMMapAddr,
-        free_subrange: &VMRange,
-    ) -> VMRange {
-        debug_assert!(free_subrange.size() >= size);
+    fn alloc_range_from(&self, size: usize, addr: VMMapAddr, free_range: &VMRange) -> VMRange {
+        debug_assert!(free_range.size() >= size);
 
-        let mut new_subrange = *free_subrange;
+        let mut new_range = *free_range;
         if let VMMapAddr::Need(addr) = addr {
-            debug_assert!(addr == new_subrange.start());
+            debug_assert!(addr == new_range.start());
         }
         if let VMMapAddr::Force(addr) = addr {
-            debug_assert!(addr == new_subrange.start());
+            debug_assert!(addr == new_range.start());
         }
 
-        new_subrange.resize(size);
-        new_subrange
+        new_range.resize(size);
+        new_range
     }
 
-    // Insert the new sub-range, and when possible, merge it with its neighbors.
-    fn insert_new_subrange(&mut self, insert_idx: usize, new_subrange: VMRange) {
-        // New sub-range can only be inserted between the two sentry sub-ranges
-        debug_assert!(0 < insert_idx && insert_idx < self.sub_ranges.len());
+    // Insert a new VMA, and when possible, merge it with its neighbors.
+    fn insert_new_vma(&mut self, insert_idx: usize, new_vma: VMArea) {
+        // A new VMA can only be inserted between the two sentry VMAs
+        debug_assert!(0 < insert_idx && insert_idx < self.vmas.len());
 
         let left_idx = insert_idx - 1;
         let right_idx = insert_idx;
 
-        // Double check the order
-        debug_assert!(self.sub_ranges[left_idx].end() <= new_subrange.start());
-        debug_assert!(new_subrange.end() <= self.sub_ranges[right_idx].start());
+        let left_vma = &self.vmas[left_idx];
+        let right_vma = &self.vmas[right_idx];
 
-        let left_mergable = if left_idx > 0 {
-            // Mergable if there is no gap between the left neighbor and the new sub-range
-            self.sub_ranges[left_idx].end() == new_subrange.start()
-        } else {
-            // The left sentry sub-range is NOT mergable with any sub-range
-            false
-        };
-        let right_mergable = if right_idx < self.sub_ranges.len() - 1 {
-            // Mergable if there is no gap between the right neighbor and the new sub-range
-            self.sub_ranges[right_idx].start() == new_subrange.end()
-        } else {
-            // The right sentry sub-range is NOT mergable with any sub-range
-            false
-        };
+        // Double check the order
+        debug_assert!(left_vma.end() <= new_vma.start());
+        debug_assert!(new_vma.end() <= right_vma.start());
+
+        let left_mergable = Self::can_merge_vmas(left_vma, &new_vma);
+        let right_mergable = Self::can_merge_vmas(&new_vma, right_vma);
+
+        // End the shared borrows so that self.vmas can be mutated below
+        drop(left_vma);
+        drop(right_vma);
 
         match (left_mergable, right_mergable) {
             (false, false) => {
-                self.sub_ranges.insert(insert_idx, new_subrange);
+                self.vmas.insert(insert_idx, new_vma);
             }
             (true, false) => {
-                self.sub_ranges[left_idx].end = new_subrange.end;
+                self.vmas[left_idx].set_end(new_vma.end);
             }
             (false, true) => {
-                self.sub_ranges[right_idx].start = new_subrange.start;
+                self.vmas[right_idx].set_start(new_vma.start);
             }
             (true, true) => {
-                self.sub_ranges[left_idx].end = self.sub_ranges[right_idx].end;
-                self.sub_ranges.remove(right_idx);
+                let left_new_end = self.vmas[right_idx].end();
+                self.vmas[left_idx].set_end(left_new_end);
+                self.vmas.remove(right_idx);
             }
         }
     }
+
+    fn can_merge_vmas(left: &VMArea, right: &VMArea) -> bool {
+        debug_assert!(left.end() <= right.start());
+
+        // Neither of the two VMAs is a sentry (whose size == 0)
+        left.size() > 0 && right.size() > 0 &&
+        // Two VMAs must border with each other
+        left.end() == right.start() &&
+        // Two VMAs must have the same memory permissions
+        left.perms() == right.perms()
+    }
+
+    fn apply_perms(protect_range: &VMRange, perms: VMPerms) {
+        extern "C" {
+            pub fn occlum_ocall_mprotect(
+                retval: *mut i32,
+                addr: *const c_void,
+                len: usize,
+                prot: i32,
+            ) -> sgx_status_t;
+        };
+
+        unsafe {
+            let mut retval = 0;
+            let addr = protect_range.start() as *const c_void;
+            let len = protect_range.size();
+            let prot = perms.bits() as i32;
+            let sgx_status = occlum_ocall_mprotect(&mut retval, addr, len, prot);
+            assert!(sgx_status == sgx_status_t::SGX_SUCCESS && retval == 0);
+        }
+    }
+}
+
+impl Drop for VMManager {
+    fn drop(&mut self) {
+        // Ensure that memory permissions are recovered
+        for vma in &self.vmas {
+            if vma.size() == 0 || vma.perms() == VMPerms::default() {
+                continue;
            }
+            Self::apply_perms(vma, VMPerms::default());
+        }
+    }
+}
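The `(same_start, same_end)` match in `mprotect` above is the heart of the change, and the interior `(false, false)` case is easier to see on concrete numbers. Below is a standalone, runnable model of the split, deliberately independent of the enclave types (`Piece` and `split` are hypothetical stand-ins for `VMArea` and the match arms, not part of this patch):

```rust
// Simplified model of how a protected range carves up its containing VMA.
#[derive(Debug, PartialEq)]
struct Piece {
    start: usize,
    end: usize,
    perms: u32,
}

fn split(vma: Piece, range: (usize, usize), new_perms: u32) -> Vec<Piece> {
    let (rs, re) = range;
    assert!(vma.start <= rs && re <= vma.end);
    let mut out = Vec::new();
    if vma.start < rs {
        // Shrunk old VMA keeps the old permissions.
        out.push(Piece { start: vma.start, end: rs, perms: vma.perms });
    }
    // New VMA gets the requested permissions.
    out.push(Piece { start: rs, end: re, perms: new_perms });
    if re < vma.end {
        // Another new VMA also keeps the old permissions.
        out.push(Piece { start: re, end: vma.end, perms: vma.perms });
    }
    out
}

fn main() {
    // Interior case (false, false): one RWX VMA of four pages becomes three.
    let pieces = split(
        Piece { start: 0x20_0000, end: 0x20_4000, perms: 0x7 }, // RWX
        (0x20_1000, 0x20_3000),
        0x1, // R
    );
    assert_eq!(pieces.len(), 3);
    println!("{:#x?}", pieces);
}
```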
diff --git a/src/libos/src/vm/vm_perms.rs b/src/libos/src/vm/vm_perms.rs
new file mode 100644
index 00000000..da4669a2
--- /dev/null
+++ b/src/libos/src/vm/vm_perms.rs
@@ -0,0 +1,34 @@
+use super::*;
+
+bitflags! {
+    pub struct VMPerms : u32 {
+        const READ = 0x1;
+        const WRITE = 0x2;
+        const EXEC = 0x4;
+        const ALL = Self::READ.bits | Self::WRITE.bits | Self::EXEC.bits;
+    }
+}
+
+impl VMPerms {
+    pub fn from_u32(bits: u32) -> Result<VMPerms> {
+        Self::from_bits(bits).ok_or_else(|| errno!(EINVAL, "invalid bits"))
+    }
+
+    pub fn can_read(&self) -> bool {
+        self.contains(VMPerms::READ)
+    }
+
+    pub fn can_write(&self) -> bool {
+        self.contains(VMPerms::WRITE)
+    }
+
+    pub fn can_execute(&self) -> bool {
+        self.contains(VMPerms::EXEC)
+    }
+}
+
+impl Default for VMPerms {
+    fn default() -> Self {
+        VMPerms::ALL
+    }
+}
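`VMPerms` keeps the exact bit values of Linux's `PROT_READ`/`PROT_WRITE`/`PROT_EXEC` (0x1/0x2/0x4). That identity is what lets `apply_perms` hand `perms.bits()` straight to the host `mprotect` OCall with no translation table. A quick unit-test-style check, assuming the `libc` crate is available for the reference constants:

```rust
#[test]
fn vmperms_mirror_linux_prot_bits() {
    // The LibOS permission bits mirror the Linux ABI one-to-one.
    assert_eq!(VMPerms::READ.bits(), libc::PROT_READ as u32);
    assert_eq!(VMPerms::WRITE.bits(), libc::PROT_WRITE as u32);
    assert_eq!(VMPerms::EXEC.bits(), libc::PROT_EXEC as u32);
    // New mappings start fully open: the default is RWX.
    assert_eq!(VMPerms::default(), VMPerms::ALL);
}
```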
diff --git a/src/libos/src/vm/vm_range.rs b/src/libos/src/vm/vm_range.rs
index 91840ffc..1dbf6175 100644
--- a/src/libos/src/vm/vm_range.rs
+++ b/src/libos/src/vm/vm_range.rs
@@ -60,9 +60,20 @@ impl VMRange {
     }
 
     pub fn resize(&mut self, new_size: usize) {
+        debug_assert!(new_size % PAGE_SIZE == 0);
         self.end = self.start + new_size;
     }
 
+    pub fn set_start(&mut self, start: usize) {
+        debug_assert!(start % PAGE_SIZE == 0 && start <= self.end);
+        self.start = start;
+    }
+
+    pub fn set_end(&mut self, end: usize) {
+        debug_assert!(end % PAGE_SIZE == 0 && end >= self.start);
+        self.end = end;
+    }
+
     pub fn empty(&self) -> bool {
         self.start == self.end
     }
@@ -75,36 +86,47 @@
         self.start() <= addr && addr < self.end()
     }
 
+    // Returns whether two ranges have a non-empty intersection.
     pub fn overlap_with(&self, other: &VMRange) -> bool {
-        self.start() < other.end() && other.start() < self.end()
+        let intersection_start = self.start().max(other.start());
+        let intersection_end = self.end().min(other.end());
+        intersection_start < intersection_end
     }
 
+    // Returns the set of ranges obtained by subtracting the other range from self.
+    //
+    // Post-condition: the returned ranges have non-zero sizes.
     pub fn subtract(&self, other: &VMRange) -> Vec<VMRange> {
+        if self.size() == 0 {
+            return vec![];
+        }
+
+        let intersection = match self.intersect(other) {
+            None => return vec![*self],
+            Some(intersection) => intersection,
+        };
+
         let self_start = self.start();
         let self_end = self.end();
-        let other_start = other.start();
-        let other_end = other.end();
+        let inter_start = intersection.start();
+        let inter_end = intersection.end();
+        debug_assert!(self_start <= inter_start);
+        debug_assert!(inter_end <= self_end);
 
-        match (self_start < other_start, other_end < self_end) {
+        match (self_start < inter_start, inter_end < self_end) {
             (false, false) => Vec::new(),
-            (false, true) => unsafe {
-                vec![VMRange::from_unchecked(self_start.max(other_end), self_end)]
-            },
-            (true, false) => unsafe {
-                vec![VMRange::from_unchecked(
-                    self_start,
-                    self_end.min(other_start),
-                )]
-            },
+            (false, true) => unsafe { vec![VMRange::from_unchecked(inter_end, self_end)] },
+            (true, false) => unsafe { vec![VMRange::from_unchecked(self_start, inter_start)] },
             (true, true) => unsafe {
                 vec![
-                    VMRange::from_unchecked(self_start, other_start),
-                    VMRange::from_unchecked(other_end, self_end),
+                    VMRange::from_unchecked(self_start, inter_start),
+                    VMRange::from_unchecked(inter_end, self_end),
                 ]
             },
         }
     }
 
+    // Returns a non-empty intersection if there is any.
     pub fn intersect(&self, other: &VMRange) -> Option<VMRange> {
         let intersection_start = self.start().max(other.start());
         let intersection_end = self.end().min(other.end());
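The rewritten `subtract` routes everything through the (possibly empty) intersection, which both simplifies the case analysis and guarantees the documented post-condition that no zero-sized leftovers are returned. A unit-test-style sketch that could sit next to `VMRange` (assuming `VMRange` implements `PartialEq`, as its use inside the `PartialEq`-deriving `VMArea` implies; addresses are page-aligned):

```rust
#[test]
fn subtract_routes_through_intersection() {
    let a = VMRange::new(0x1000, 0x4000).unwrap();
    let b = VMRange::new(0x2000, 0x3000).unwrap();
    assert!(a.overlap_with(&b));

    // b sits strictly inside a, so the (true, true) arm returns two pieces.
    let rest = a.subtract(&b);
    assert_eq!(rest, vec![
        VMRange::new(0x1000, 0x2000).unwrap(),
        VMRange::new(0x3000, 0x4000).unwrap(),
    ]);

    // Disjoint ranges: intersect() is None, so a comes back untouched.
    let c = VMRange::new(0x8000, 0x9000).unwrap();
    assert_eq!(a.subtract(&c), vec![a]);
}
```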
diff --git a/src/pal/src/ocalls/mem.c b/src/pal/src/ocalls/mem.c
index 2f7d3cde..8d8f8fe7 100644
--- a/src/pal/src/ocalls/mem.c
+++ b/src/pal/src/ocalls/mem.c
@@ -1,4 +1,5 @@
 #include <stdlib.h>
+#include <sys/mman.h>
 #include "ocalls.h"
 
 void *occlum_ocall_posix_memalign(size_t alignment, size_t size) {
@@ -25,3 +26,7 @@
 void occlum_ocall_free(void *ptr) {
     free(ptr);
 }
+
+int occlum_ocall_mprotect(void *addr, size_t len, int prot) {
+    return mprotect(addr, len, prot);
+}
1 }, + }; // in pages + size_t offsets[][3] = { + { 0, 3, 2 }, + }; // in pages + int prots[][3] = { + { PROT_NONE, PROT_NONE, PROT_NONE }, + }; + for (int i = 0; i < ARRAY_SIZE(lens); i++) { + int flags = MAP_PRIVATE | MAP_ANONYMOUS; + void *buf = mmap(NULL, total_len * PAGE_SIZE, init_prot, flags, -1, 0); + if (buf == MAP_FAILED) { + THROW_ERROR("mmap failed"); + } + + for (int j = 0; j < 3; j++) { + size_t len = lens[i][j] * PAGE_SIZE; + size_t offset = offsets[i][j] * PAGE_SIZE; + int prot = prots[i][j]; + void *tmp_buf = (char *)buf + offset; + int ret = mprotect(tmp_buf, len, prot); + if (ret < 0) { + THROW_ERROR("mprotect failed"); + } + } + + int ret = munmap(buf, total_len * PAGE_SIZE); + if (ret < 0) { + THROW_ERROR("munmap failed"); + } + } + return 0; +} + +int test_mprotect_with_zero_len() { + int flags = MAP_PRIVATE | MAP_ANONYMOUS; + void *buf = mmap(NULL, PAGE_SIZE, PROT_NONE, flags, -1, 0); + if (buf == MAP_FAILED) { + THROW_ERROR("mmap failed"); + } + + int ret = mprotect(buf, 0, PROT_NONE); + if (ret < 0) { + THROW_ERROR("mprotect failed"); + } + + ret = munmap(buf, PAGE_SIZE); + if (ret < 0) { + THROW_ERROR("munmap failed"); + } + + return 0; +} + +int test_mprotect_with_invalid_addr() { + int ret = mprotect(NULL, PAGE_SIZE, PROT_NONE); + if (ret == 0 || errno != ENOMEM) { + THROW_ERROR("using invalid addr should have failed"); + } + return 0; +} + +int test_mprotect_with_invalid_prot() { + int invalid_prot = 0x1234; // invalid protection bits + void *valid_addr = &invalid_prot; + size_t valid_len = PAGE_SIZE; + int ret = mprotect(valid_addr, valid_len, invalid_prot); + if (ret == 0 || errno != EINVAL) { + THROW_ERROR("using invalid addr should have failed"); + } + return 0; +} + // ============================================================================ // Test suite main // ============================================================================ @@ -762,6 +983,12 @@ static test_case_t test_cases[] = { TEST_CASE(test_mremap), TEST_CASE(test_mremap_subrange), TEST_CASE(test_mremap_with_fixed_addr), + TEST_CASE(test_mprotect_once), + TEST_CASE(test_mprotect_twice), + TEST_CASE(test_mprotect_triple), + TEST_CASE(test_mprotect_with_zero_len), + TEST_CASE(test_mprotect_with_invalid_addr), + TEST_CASE(test_mprotect_with_invalid_prot), }; int main() {