Add mprotect system call

This commit is contained in:
Tate, Hongliang Tian 2020-06-16 16:38:14 +08:00
parent b9b9b1032c
commit bca0663972
10 changed files with 622 additions and 133 deletions

@ -107,6 +107,8 @@ enclave {
void* occlum_ocall_posix_memalign(size_t alignment, size_t size); void* occlum_ocall_posix_memalign(size_t alignment, size_t size);
void occlum_ocall_free([user_check] void* ptr); void occlum_ocall_free([user_check] void* ptr);
int occlum_ocall_mprotect([user_check] void* addr, size_t len, int prot);
void occlum_ocall_sched_yield(void); void occlum_ocall_sched_yield(void);
int occlum_ocall_sched_setaffinity( int occlum_ocall_sched_setaffinity(
int host_tid, int host_tid,

@ -745,8 +745,9 @@ fn do_mremap(
Ok(addr as isize) Ok(addr as isize)
} }
fn do_mprotect(addr: usize, len: usize, prot: u32) -> Result<isize> { fn do_mprotect(addr: usize, len: usize, perms: u32) -> Result<isize> {
// TODO: implement it let perms = VMPerms::from_u32(perms as u32)?;
vm::do_mprotect(addr, len, perms)?;
Ok(0) Ok(0)
} }

@ -5,15 +5,18 @@ use std::fmt;
mod process_vm; mod process_vm;
mod user_space_vm; mod user_space_vm;
mod vm_area;
mod vm_layout; mod vm_layout;
mod vm_manager; mod vm_manager;
mod vm_perms;
mod vm_range; mod vm_range;
use self::vm_layout::VMLayout; use self::vm_layout::VMLayout;
use self::vm_manager::{VMManager, VMMapOptionsBuilder}; use self::vm_manager::{VMManager, VMMapOptionsBuilder};
pub use self::process_vm::{MMapFlags, MRemapFlags, ProcessVM, ProcessVMBuilder, VMPerms}; pub use self::process_vm::{MMapFlags, MRemapFlags, ProcessVM, ProcessVMBuilder};
pub use self::user_space_vm::USER_SPACE_VM_MANAGER; pub use self::user_space_vm::USER_SPACE_VM_MANAGER;
pub use self::vm_perms::VMPerms;
pub use self::vm_range::VMRange; pub use self::vm_range::VMRange;
pub fn do_mmap( pub fn do_mmap(
@ -63,6 +66,16 @@ pub fn do_mremap(
current_vm.mremap(old_addr, old_size, new_size, flags) current_vm.mremap(old_addr, old_size, new_size, flags)
} }
pub fn do_mprotect(addr: usize, size: usize, perms: VMPerms) -> Result<()> {
debug!(
"mprotect: addr: {:#x}, size: {:#x}, perms: {:?}",
addr, size, perms
);
let current = current!();
let mut current_vm = current.vm().lock().unwrap();
current_vm.mprotect(addr, size, perms)
}
pub fn do_brk(addr: usize) -> Result<usize> { pub fn do_brk(addr: usize) -> Result<usize> {
debug!("brk: addr: {:#x}", addr); debug!("brk: addr: {:#x}", addr);
let current = current!(); let current = current!();

@ -6,6 +6,7 @@ use super::user_space_vm::{UserSpaceVMManager, UserSpaceVMRange, USER_SPACE_VM_M
use super::vm_manager::{ use super::vm_manager::{
VMInitializer, VMManager, VMMapAddr, VMMapOptions, VMMapOptionsBuilder, VMRemapOptions, VMInitializer, VMManager, VMMapAddr, VMMapOptions, VMMapOptionsBuilder, VMRemapOptions,
}; };
use super::vm_perms::VMPerms;
#[derive(Debug)] #[derive(Debug)]
pub struct ProcessVMBuilder<'a, 'b> { pub struct ProcessVMBuilder<'a, 'b> {
@ -202,12 +203,18 @@ impl<'a, 'b> ProcessVMBuilder<'a, 'b> {
/// The per-process virtual memory /// The per-process virtual memory
#[derive(Debug)] #[derive(Debug)]
pub struct ProcessVM { pub struct ProcessVM {
process_range: UserSpaceVMRange, mmap_manager: VMManager,
elf_ranges: Vec<VMRange>, elf_ranges: Vec<VMRange>,
heap_range: VMRange, heap_range: VMRange,
stack_range: VMRange, stack_range: VMRange,
brk: usize, brk: usize,
mmap_manager: VMManager, // Memory safety notes: the process_range field must be the last one.
//
// Rust drops fields in the same order as they are declared. So by making
// process_range the last field, we ensure that when all other fields are
// dropped, their drop methods (if provided) can still access the memory
// region represented by the process_range field.
process_range: UserSpaceVMRange,
} }
impl Default for ProcessVM { impl Default for ProcessVM {
@ -313,6 +320,7 @@ impl ProcessVM {
let mmap_options = VMMapOptionsBuilder::default() let mmap_options = VMMapOptionsBuilder::default()
.size(size) .size(size)
.addr(addr_option) .addr(addr_option)
.perms(perms)
.initializer(initializer) .initializer(initializer)
.build()?; .build()?;
let mmap_addr = self.mmap_manager.mmap(&mmap_options)?; let mmap_addr = self.mmap_manager.mmap(&mmap_options)?;
@ -340,6 +348,20 @@ impl ProcessVM {
self.mmap_manager.munmap(addr, size) self.mmap_manager.munmap(addr, size)
} }
pub fn mprotect(&mut self, addr: usize, size: usize, perms: VMPerms) -> Result<()> {
let protect_range = VMRange::new_with_size(addr, size)?;
if !self.process_range.range().is_superset_of(&protect_range) {
return_errno!(ENOMEM, "invalid range");
}
// TODO: support mprotect vm regions in addition to mmap
if !self.mmap_manager.range().is_superset_of(&protect_range) {
warn!("Do not support mprotect memory outside the mmap region yet");
return Ok(());
}
self.mmap_manager.mprotect(addr, size, perms)
}
pub fn find_mmap_region(&self, addr: usize) -> Result<&VMRange> { pub fn find_mmap_region(&self, addr: usize) -> Result<&VMRange> {
self.mmap_manager.find_mmap_region(addr) self.mmap_manager.find_mmap_region(addr)
} }
@ -412,32 +434,6 @@ impl Default for MRemapFlags {
} }
} }
bitflags! {
pub struct VMPerms : u32 {
const READ = 0x1;
const WRITE = 0x2;
const EXEC = 0x4;
}
}
impl VMPerms {
pub fn can_read(&self) -> bool {
self.contains(VMPerms::READ)
}
pub fn can_write(&self) -> bool {
self.contains(VMPerms::WRITE)
}
pub fn can_execute(&self) -> bool {
self.contains(VMPerms::EXEC)
}
pub fn from_u32(bits: u32) -> Result<VMPerms> {
VMPerms::from_bits(bits).ok_or_else(|| errno!(EINVAL, "unknown permission bits"))
}
}
unsafe fn fill_zeros(addr: usize, size: usize) { unsafe fn fill_zeros(addr: usize, size: usize) {
let ptr = addr as *mut u8; let ptr = addr as *mut u8;
let buf = std::slice::from_raw_parts_mut(ptr, size); let buf = std::slice::from_raw_parts_mut(ptr, size);

@ -0,0 +1,51 @@
use std::ops::{Deref, DerefMut};
use super::vm_perms::VMPerms;
use super::vm_range::VMRange;
use super::*;
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct VMArea {
range: VMRange,
perms: VMPerms,
}
impl VMArea {
pub fn new(range: VMRange, perms: VMPerms) -> Self {
Self { range, perms }
}
pub fn perms(&self) -> VMPerms {
self.perms
}
pub fn range(&self) -> &VMRange {
&self.range
}
pub fn set_perms(&mut self, new_perms: VMPerms) {
self.perms = new_perms;
}
pub fn subtract(&self, other: &VMRange) -> Vec<VMArea> {
self.deref()
.subtract(other)
.iter()
.map(|range| VMArea::new(*range, self.perms()))
.collect()
}
}
impl Deref for VMArea {
type Target = VMRange;
fn deref(&self) -> &Self::Target {
&self.range
}
}
impl DerefMut for VMArea {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.range
}
}

@ -1,5 +1,8 @@
use super::*; use super::*;
use super::vm_area::VMArea;
use super::vm_perms::VMPerms;
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub enum VMInitializer { pub enum VMInitializer {
DoNothing(), DoNothing(),
@ -66,6 +69,7 @@ impl Default for VMMapAddr {
pub struct VMMapOptions { pub struct VMMapOptions {
size: usize, size: usize,
align: usize, align: usize,
perms: VMPerms,
addr: VMMapAddr, addr: VMMapAddr,
initializer: VMInitializer, initializer: VMInitializer,
} }
@ -89,6 +93,9 @@ impl VMMapOptionsBuilder {
} }
align align
}; };
let perms = self
.perms
.ok_or_else(|| errno!(EINVAL, "perms must be given"))?;
let addr = { let addr = {
let addr = self.addr.unwrap_or_default(); let addr = self.addr.unwrap_or_default();
match addr { match addr {
@ -113,6 +120,7 @@ impl VMMapOptionsBuilder {
Ok(VMMapOptions { Ok(VMMapOptions {
size, size,
align, align,
perms,
addr, addr,
initializer, initializer,
}) })
@ -128,6 +136,10 @@ impl VMMapOptions {
&self.addr &self.addr
} }
pub fn perms(&self) -> &VMPerms {
&self.perms
}
pub fn initializer(&self) -> &VMInitializer { pub fn initializer(&self) -> &VMInitializer {
&self.initializer &self.initializer
} }
@ -201,52 +213,61 @@ impl VMRemapOptions {
/// ///
/// # Invariants /// # Invariants
/// ///
/// Behind the scene, VMManager maintains a list of VMRange that have been allocated. /// Behind the scene, VMManager maintains a list of VMArea that have been allocated.
/// (denoted as `self.sub_ranges`). To reason about the correctness of VMManager, we give /// (denoted as `self.vmas`). To reason about the correctness of VMManager, we give
/// the set of invariants hold by VMManager. /// the set of invariants hold by VMManager.
/// ///
/// 1. The rule of sentry: /// 1. The rule of sentry:
/// ``` /// ```
/// self.range.begin() == self.sub_ranges[0].start() == self.sub_ranges[0].end() /// self.range.begin() == self.vmas[0].start() == self.vmas[0].end()
/// ``` /// ```
/// and /// and
/// ``` /// ```
/// self.range.end() == self.sub_ranges[N-1].start() == self.sub_ranges[N-1].end() /// self.range.end() == self.vmas[N-1].start() == self.vmas[N-1].end()
/// ``` /// ```
/// where `N = self.sub_ranges.len()`. /// where `N = self.vmas.len()`.
/// ///
/// 2. The rule of non-emptyness: /// 2. The rule of non-emptyness:
/// ``` /// ```
/// self.sub_ranges[i].size() > 0, for 1 <= i < self.sub_ranges.len() - 1 /// self.vmas[i].size() > 0, for 1 <= i < self.vmas.len() - 1
/// ``` /// ```
/// ///
/// 3. The rule of ordering: /// 3. The rule of ordering:
/// ``` /// ```
/// self.sub_ranges[i].end() <= self.sub_ranges[i+1].start() for 0 <= i < self.sub_ranges.len() - 1 /// self.vmas[i].end() <= self.vmas[i+1].start() for 0 <= i < self.vmas.len() - 1
/// ``` /// ```
/// ///
/// 4. The rule of non-mergablility: /// 4. The rule of non-mergablility:
/// ``` /// ```
/// self.sub_ranges[i].end() != self.sub_ranges[i+1].start() for 1 <= i < self.sub_ranges.len() - 2 /// self.vmas[i].end() != self.vmas[i+1].start() || self.vmas[i].perms() != self.vmas[i+1].perms()
/// for 1 <= i < self.vmas.len() - 2
/// ``` /// ```
/// ///
#[derive(Debug, Default)] #[derive(Debug, Default)]
pub struct VMManager { pub struct VMManager {
range: VMRange, range: VMRange,
sub_ranges: Vec<VMRange>, vmas: Vec<VMArea>,
} }
impl VMManager { impl VMManager {
pub fn from(addr: usize, size: usize) -> Result<VMManager> { pub fn from(addr: usize, size: usize) -> Result<VMManager> {
let range = VMRange::new(addr, addr + size)?; let range = VMRange::new(addr, addr + size)?;
let sub_ranges = { let vmas = {
let start = range.start(); let start = range.start();
let end = range.end(); let end = range.end();
let start_sentry = VMRange::new(start, start)?; let start_sentry = {
let end_sentry = VMRange::new(end, end)?; let range = VMRange::new_empty(start)?;
let perms = VMPerms::empty();
VMArea::new(range, perms)
};
let end_sentry = {
let range = VMRange::new_empty(end)?;
let perms = VMPerms::empty();
VMArea::new(range, perms)
};
vec![start_sentry, end_sentry] vec![start_sentry, end_sentry]
}; };
Ok(VMManager { range, sub_ranges }) Ok(VMManager { range, vmas })
} }
pub fn range(&self) -> &VMRange { pub fn range(&self) -> &VMRange {
@ -262,21 +283,24 @@ impl VMManager {
self.munmap(addr, size)?; self.munmap(addr, size)?;
} }
// Allocate a new subrange for this mmap request // Allocate a new range for this mmap request
let (insert_idx, free_subrange) = self.find_free_subrange(size, addr)?; let (insert_idx, free_range) = self.find_free_range(size, addr)?;
let new_subrange = self.alloc_subrange_from(size, addr, &free_subrange); let new_range = self.alloc_range_from(size, addr, &free_range);
let new_subrange_addr = new_subrange.start(); let new_addr = new_range.start();
let new_vma = VMArea::new(new_range, *options.perms());
// Initialize the memory of the new subrange // Initialize the memory of the new range
unsafe { unsafe {
let buf = new_subrange.as_slice_mut(); let buf = new_vma.as_slice_mut();
options.initializer.init_slice(buf)?; options.initializer.init_slice(buf)?;
} }
// Set memory permissions
Self::apply_perms(&new_vma, new_vma.perms());
// After initializing, we can safely insert the new subrange // After initializing, we can safely insert the new VMA
self.insert_new_subrange(insert_idx, new_subrange); self.insert_new_vma(insert_idx, new_vma);
Ok(new_subrange_addr) Ok(new_addr)
} }
pub fn munmap(&mut self, addr: usize, size: usize) -> Result<()> { pub fn munmap(&mut self, addr: usize, size: usize) -> Result<()> {
@ -301,24 +325,27 @@ impl VMManager {
effective_munmap_range effective_munmap_range
}; };
let new_sub_ranges = self let new_vmas = self
.sub_ranges .vmas
.iter() .iter()
.flat_map(|subrange| { .flat_map(|vma| {
// Keep the two sentry subranges intact // Keep the two sentry VMA intact
if subrange.size() == 0 { if vma.size() == 0 {
return vec![*subrange]; return vec![*vma];
} }
let unmapped_subrange = match subrange.intersect(&munmap_range) { let intersection_range = match vma.intersect(&munmap_range) {
None => return vec![*subrange], None => return vec![*vma],
Some(unmapped_subrange) => unmapped_subrange, Some(intersection_range) => intersection_range,
}; };
subrange.subtract(&unmapped_subrange) // Reset memory permissions
Self::apply_perms(&intersection_range, VMPerms::default());
vma.subtract(&intersection_range)
}) })
.collect(); .collect();
self.sub_ranges = new_sub_ranges; self.vmas = new_vmas;
Ok(()) Ok(())
} }
@ -343,9 +370,15 @@ impl VMManager {
SizeType::Growing SizeType::Growing
}; };
// The old range must not span over multiple sub-ranges // Get the memory permissions of the old range
self.find_containing_subrange_idx(&old_range) let perms = {
.ok_or_else(|| errno!(EFAULT, "invalid range"))?; // The old range must be contained in one VMA
let idx = self
.find_containing_vma_idx(&old_range)
.ok_or_else(|| errno!(EFAULT, "invalid range"))?;
let containing_vma = &self.vmas[idx];
containing_vma.perms()
};
// Implement mremap as one optional mmap followed by one optional munmap. // Implement mremap as one optional mmap followed by one optional munmap.
// //
@ -362,6 +395,7 @@ impl VMManager {
let mmap_opts = VMMapOptionsBuilder::default() let mmap_opts = VMMapOptionsBuilder::default()
.size(new_size - old_size) .size(new_size - old_size)
.addr(VMMapAddr::Need(old_range.end())) .addr(VMMapAddr::Need(old_range.end()))
.perms(perms)
.initializer(VMInitializer::FillZeros()) .initializer(VMInitializer::FillZeros())
.build()?; .build()?;
let ret_addr = Some(old_addr); let ret_addr = Some(old_addr);
@ -374,6 +408,7 @@ impl VMManager {
let mmap_ops = VMMapOptionsBuilder::default() let mmap_ops = VMMapOptionsBuilder::default()
.size(prefered_new_range.size()) .size(prefered_new_range.size())
.addr(VMMapAddr::Need(prefered_new_range.start())) .addr(VMMapAddr::Need(prefered_new_range.start()))
.perms(perms)
.initializer(VMInitializer::FillZeros()) .initializer(VMInitializer::FillZeros())
.build()?; .build()?;
(Some(mmap_ops), Some(old_addr)) (Some(mmap_ops), Some(old_addr))
@ -381,6 +416,7 @@ impl VMManager {
let mmap_ops = VMMapOptionsBuilder::default() let mmap_ops = VMMapOptionsBuilder::default()
.size(new_size) .size(new_size)
.addr(VMMapAddr::Any) .addr(VMMapAddr::Any)
.perms(perms)
.initializer(VMInitializer::CopyFrom { range: old_range }) .initializer(VMInitializer::CopyFrom { range: old_range })
.build()?; .build()?;
// Cannot determine the returned address for now, which can only be obtained after calling mmap // Cannot determine the returned address for now, which can only be obtained after calling mmap
@ -392,6 +428,7 @@ impl VMManager {
let mmap_opts = VMMapOptionsBuilder::default() let mmap_opts = VMMapOptionsBuilder::default()
.size(new_size) .size(new_size)
.addr(VMMapAddr::Force(new_addr)) .addr(VMMapAddr::Force(new_addr))
.perms(perms)
.initializer(VMInitializer::CopyFrom { range: old_range }) .initializer(VMInitializer::CopyFrom { range: old_range })
.build()?; .build()?;
let ret_addr = Some(new_addr); let ret_addr = Some(new_addr);
@ -442,41 +479,109 @@ impl VMManager {
Ok(ret_addr.unwrap()) Ok(ret_addr.unwrap())
} }
pub fn mprotect(&mut self, addr: usize, size: usize, new_perms: VMPerms) -> Result<()> {
let protect_range = VMRange::new_with_size(addr, size)?;
// FIXME: the current implementation requires the target range to be
// contained in exact one VMA.
let containing_idx = self
.find_containing_vma_idx(&protect_range)
.ok_or_else(|| errno!(ENOMEM, "invalid range"))?;
let containing_vma = &self.vmas[containing_idx];
let old_perms = containing_vma.perms();
if new_perms == old_perms {
return Ok(());
}
let same_start = protect_range.start() == containing_vma.start();
let same_end = protect_range.end() == containing_vma.end();
let containing_vma = &mut self.vmas[containing_idx];
match (same_start, same_end) {
(true, true) => {
containing_vma.set_perms(new_perms);
Self::apply_perms(containing_vma, containing_vma.perms());
}
(false, true) => {
containing_vma.set_end(protect_range.start());
let new_vma = VMArea::new(protect_range, new_perms);
Self::apply_perms(&new_vma, new_vma.perms());
self.insert_new_vma(containing_idx + 1, new_vma);
}
(true, false) => {
containing_vma.set_start(protect_range.end());
let new_vma = VMArea::new(protect_range, new_perms);
Self::apply_perms(&new_vma, new_vma.perms());
self.insert_new_vma(containing_idx, new_vma);
}
(false, false) => {
// The containing VMA is divided into three VMAs:
// Shrinked old VMA: [containing_vma.start, protect_range.start)
// New VMA: [protect_range.start, protect_range.end)
// Another new vma: [protect_range.end, containing_vma.end)
let old_end = containing_vma.end();
let protect_end = protect_range.end();
// Shrinked old VMA
containing_vma.set_end(protect_range.start());
// New VMA
let new_vma = VMArea::new(protect_range, new_perms);
Self::apply_perms(&new_vma, new_vma.perms());
self.insert_new_vma(containing_idx + 1, new_vma);
// Another new VMA
let new_vma2 = {
let range = VMRange::new(protect_end, old_end).unwrap();
VMArea::new(range, old_perms)
};
self.insert_new_vma(containing_idx + 2, new_vma2);
}
}
Ok(())
}
pub fn find_mmap_region(&self, addr: usize) -> Result<&VMRange> { pub fn find_mmap_region(&self, addr: usize) -> Result<&VMRange> {
self.sub_ranges self.vmas
.iter() .iter()
.find(|subrange| subrange.contains(addr)) .map(|vma| vma.range())
.find(|vma| vma.contains(addr))
.ok_or_else(|| errno!(ESRCH, "no mmap regions that contains the address")) .ok_or_else(|| errno!(ESRCH, "no mmap regions that contains the address"))
} }
// Find a subrange that contains the given range and returns the index of the subrange // Find a VMA that contains the given range, returning the VMA's index
fn find_containing_subrange_idx(&self, target_range: &VMRange) -> Option<usize> { fn find_containing_vma_idx(&self, target_range: &VMRange) -> Option<usize> {
self.sub_ranges self.vmas
.iter() .iter()
.position(|subrange| subrange.is_superset_of(target_range)) .position(|vma| vma.is_superset_of(target_range))
} }
// Returns whether the requested range is free // Returns whether the requested range is free
fn is_free_range(&self, request_range: &VMRange) -> bool { fn is_free_range(&self, request_range: &VMRange) -> bool {
self.range.is_superset_of(request_range) self.range.is_superset_of(request_range)
&& self && self
.sub_ranges .vmas
.iter() .iter()
.all(|range| range.overlap_with(request_range) == false) .all(|range| range.overlap_with(request_range) == false)
} }
// Find the free subrange that satisfies the constraints of size and address // Find the free range that satisfies the constraints of size and address
fn find_free_subrange(&self, size: usize, addr: VMMapAddr) -> Result<(usize, VMRange)> { fn find_free_range(&self, size: usize, addr: VMMapAddr) -> Result<(usize, VMRange)> {
// TODO: reduce the complexity from O(N) to O(log(N)), where N is // TODO: reduce the complexity from O(N) to O(log(N)), where N is
// the number of existing subranges. // the number of existing VMAs.
// Record the minimal free range that satisfies the contraints // Record the minimal free range that satisfies the contraints
let mut result_free_range: Option<VMRange> = None; let mut result_free_range: Option<VMRange> = None;
let mut result_idx: Option<usize> = None; let mut result_idx: Option<usize> = None;
for (idx, range_pair) in self.sub_ranges.windows(2).enumerate() { for (idx, range_pair) in self.vmas.windows(2).enumerate() {
// Since we have two sentry sub_ranges at both ends, we can be sure that the free // Since we have two sentry vmas at both ends, we can be sure that the free
// space only appears between two consecutive sub_ranges. // space only appears between two consecutive vmas.
let pre_range = &range_pair[0]; let pre_range = &range_pair[0];
let next_range = &range_pair[1]; let next_range = &range_pair[1];
@ -539,68 +644,101 @@ impl VMManager {
Ok((insert_idx, free_range)) Ok((insert_idx, free_range))
} }
fn alloc_subrange_from( fn alloc_range_from(&self, size: usize, addr: VMMapAddr, free_range: &VMRange) -> VMRange {
&self, debug_assert!(free_range.size() >= size);
size: usize,
addr: VMMapAddr,
free_subrange: &VMRange,
) -> VMRange {
debug_assert!(free_subrange.size() >= size);
let mut new_subrange = *free_subrange; let mut new_range = *free_range;
if let VMMapAddr::Need(addr) = addr { if let VMMapAddr::Need(addr) = addr {
debug_assert!(addr == new_subrange.start()); debug_assert!(addr == new_range.start());
} }
if let VMMapAddr::Force(addr) = addr { if let VMMapAddr::Force(addr) = addr {
debug_assert!(addr == new_subrange.start()); debug_assert!(addr == new_range.start());
} }
new_subrange.resize(size); new_range.resize(size);
new_subrange new_range
} }
// Insert the new sub-range, and when possible, merge it with its neighbors. // Insert a new VMA, and when possible, merge it with its neighbors.
fn insert_new_subrange(&mut self, insert_idx: usize, new_subrange: VMRange) { fn insert_new_vma(&mut self, insert_idx: usize, new_vma: VMArea) {
// New sub-range can only be inserted between the two sentry sub-ranges // New VMA can only be inserted between the two sentry VMAs
debug_assert!(0 < insert_idx && insert_idx < self.sub_ranges.len()); debug_assert!(0 < insert_idx && insert_idx < self.vmas.len());
let left_idx = insert_idx - 1; let left_idx = insert_idx - 1;
let right_idx = insert_idx; let right_idx = insert_idx;
// Double check the order let left_vma = &self.vmas[left_idx];
debug_assert!(self.sub_ranges[left_idx].end() <= new_subrange.start()); let right_vma = &self.vmas[right_idx];
debug_assert!(new_subrange.end() <= self.sub_ranges[right_idx].start());
let left_mergable = if left_idx > 0 { // Double check the order
// Mergable if there is no gap between the left neighbor and the new sub-range debug_assert!(left_vma.end() <= new_vma.start());
self.sub_ranges[left_idx].end() == new_subrange.start() debug_assert!(new_vma.end() <= right_vma.start());
} else {
// The left sentry sub-range is NOT mergable with any sub-range let left_mergable = Self::can_merge_vmas(left_vma, &new_vma);
false let right_mergable = Self::can_merge_vmas(&new_vma, right_vma);
};
let right_mergable = if right_idx < self.sub_ranges.len() - 1 { drop(left_vma);
// Mergable if there is no gap between the right neighbor and the new sub-range drop(right_vma);
self.sub_ranges[right_idx].start() == new_subrange.end()
} else {
// The right sentry sub-range is NOT mergable with any sub-range
false
};
match (left_mergable, right_mergable) { match (left_mergable, right_mergable) {
(false, false) => { (false, false) => {
self.sub_ranges.insert(insert_idx, new_subrange); self.vmas.insert(insert_idx, new_vma);
} }
(true, false) => { (true, false) => {
self.sub_ranges[left_idx].end = new_subrange.end; self.vmas[left_idx].set_end(new_vma.end);
} }
(false, true) => { (false, true) => {
self.sub_ranges[right_idx].start = new_subrange.start; self.vmas[right_idx].set_start(new_vma.start);
} }
(true, true) => { (true, true) => {
self.sub_ranges[left_idx].end = self.sub_ranges[right_idx].end; let left_new_end = self.vmas[right_idx].end();
self.sub_ranges.remove(right_idx); self.vmas[left_idx].set_end(left_new_end);
self.vmas.remove(right_idx);
} }
} }
} }
fn can_merge_vmas(left: &VMArea, right: &VMArea) -> bool {
debug_assert!(left.end() <= right.start());
// Both of the two VMAs are not sentry (whose size == 0)
left.size() > 0 && right.size() > 0 &&
// Two VMAs must border with each other
left.end() == right.start() &&
// Two VMAs must have the same memory permissions
left.perms() == right.perms()
}
fn apply_perms(protect_range: &VMRange, perms: VMPerms) {
extern "C" {
pub fn occlum_ocall_mprotect(
retval: *mut i32,
addr: *const c_void,
len: usize,
prot: i32,
) -> sgx_status_t;
};
unsafe {
let mut retval = 0;
let addr = protect_range.start() as *const c_void;
let len = protect_range.size();
let prot = perms.bits() as i32;
let sgx_status = occlum_ocall_mprotect(&mut retval, addr, len, prot);
assert!(sgx_status == sgx_status_t::SGX_SUCCESS && retval == 0);
}
}
}
impl Drop for VMManager {
fn drop(&mut self) {
// Ensure that memory permissions are recovered
for vma in &self.vmas {
if vma.size() == 0 || vma.perms() == VMPerms::default() {
continue;
}
Self::apply_perms(vma, VMPerms::default());
}
}
} }

@ -0,0 +1,34 @@
use super::*;
bitflags! {
pub struct VMPerms : u32 {
const READ = 0x1;
const WRITE = 0x2;
const EXEC = 0x4;
const ALL = Self::READ.bits | Self::WRITE.bits | Self::EXEC.bits;
}
}
impl VMPerms {
pub fn from_u32(bits: u32) -> Result<VMPerms> {
Self::from_bits(bits).ok_or_else(|| errno!(EINVAL, "invalid bits"))
}
pub fn can_read(&self) -> bool {
self.contains(VMPerms::READ)
}
pub fn can_write(&self) -> bool {
self.contains(VMPerms::WRITE)
}
pub fn can_execute(&self) -> bool {
self.contains(VMPerms::EXEC)
}
}
impl Default for VMPerms {
fn default() -> Self {
VMPerms::ALL
}
}

@ -60,9 +60,20 @@ impl VMRange {
} }
pub fn resize(&mut self, new_size: usize) { pub fn resize(&mut self, new_size: usize) {
debug_assert!(new_size % PAGE_SIZE == 0);
self.end = self.start + new_size; self.end = self.start + new_size;
} }
pub fn set_start(&mut self, start: usize) {
debug_assert!(start % PAGE_SIZE == 0 && start <= self.end);
self.start = start;
}
pub fn set_end(&mut self, end: usize) {
debug_assert!(end % PAGE_SIZE == 0 && end >= self.start);
self.end = end;
}
pub fn empty(&self) -> bool { pub fn empty(&self) -> bool {
self.start == self.end self.start == self.end
} }
@ -75,36 +86,47 @@ impl VMRange {
self.start() <= addr && addr < self.end() self.start() <= addr && addr < self.end()
} }
// Returns whether two ranges have non-empty interesection.
pub fn overlap_with(&self, other: &VMRange) -> bool { pub fn overlap_with(&self, other: &VMRange) -> bool {
self.start() < other.end() && other.start() < self.end() let intersection_start = self.start().max(other.start());
let intersection_end = self.end().min(other.end());
intersection_start < intersection_end
} }
// Returns a set of ranges by subtracting self with the other.
//
// Post-condition: the returned ranges have non-zero sizes.
pub fn subtract(&self, other: &VMRange) -> Vec<VMRange> { pub fn subtract(&self, other: &VMRange) -> Vec<VMRange> {
if self.size() == 0 {
return vec![];
}
let intersection = match self.intersect(other) {
None => return vec![*self],
Some(intersection) => intersection,
};
let self_start = self.start(); let self_start = self.start();
let self_end = self.end(); let self_end = self.end();
let other_start = other.start(); let inter_start = intersection.start();
let other_end = other.end(); let inter_end = intersection.end();
debug_assert!(self_start <= inter_start);
debug_assert!(inter_end <= self_end);
match (self_start < other_start, other_end < self_end) { match (self_start < inter_start, inter_end < self_end) {
(false, false) => Vec::new(), (false, false) => Vec::new(),
(false, true) => unsafe { (false, true) => unsafe { vec![VMRange::from_unchecked(inter_end, self_end)] },
vec![VMRange::from_unchecked(self_start.max(other_end), self_end)] (true, false) => unsafe { vec![VMRange::from_unchecked(self_start, inter_start)] },
},
(true, false) => unsafe {
vec![VMRange::from_unchecked(
self_start,
self_end.min(other_start),
)]
},
(true, true) => unsafe { (true, true) => unsafe {
vec![ vec![
VMRange::from_unchecked(self_start, other_start), VMRange::from_unchecked(self_start, inter_start),
VMRange::from_unchecked(other_end, self_end), VMRange::from_unchecked(inter_end, self_end),
] ]
}, },
} }
} }
// Returns an non-empty intersection if where is any
pub fn intersect(&self, other: &VMRange) -> Option<VMRange> { pub fn intersect(&self, other: &VMRange) -> Option<VMRange> {
let intersection_start = self.start().max(other.start()); let intersection_start = self.start().max(other.start());
let intersection_end = self.end().min(other.end()); let intersection_end = self.end().min(other.end());

@ -1,4 +1,5 @@
#include <stdlib.h> #include <stdlib.h>
#include <sys/mman.h>
#include "ocalls.h" #include "ocalls.h"
void *occlum_ocall_posix_memalign(size_t alignment, size_t size) { void *occlum_ocall_posix_memalign(size_t alignment, size_t size) {
@ -25,3 +26,7 @@ void *occlum_ocall_posix_memalign(size_t alignment, size_t size) {
void occlum_ocall_free(void *ptr) { void occlum_ocall_free(void *ptr) {
free(ptr); free(ptr);
} }
int occlum_ocall_mprotect(void *addr, size_t len, int prot) {
return mprotect(addr, len, prot);
}

@ -733,6 +733,227 @@ int test_mremap_with_fixed_addr() {
return 0; return 0;
} }
// ============================================================================
// Test cases for mprotect
// ============================================================================
int test_mprotect_once() {
// The memory permissions initially looks like below:
//
// Pages: #0 #1 #2 #3
// -------------------------------------
// Memory perms: [ ][ ][ ][ ]
size_t total_len = 4; // in pages
int init_prot = PROT_NONE;
// The four settings for mprotect and its resulting memory perms.
//
// Pages: #0 #1 #2 #3
// -------------------------------------
// Setting (i = 0):
// mprotect: [RW ][RW ][RW ][RW ]
// result: [RW ][RW ][RW ][RW ]
// Setting (i = 1):
// mprotect: [RW ]
// result: [RW ][ ][ ][ ]
// Setting (i = 2):
// mprotect: [RW ][RW ]
// result: [ ][ ][RW ][RW ]
// Setting (i = 3):
// mprotect: [RW ][RW ]
// result: [ ][RW ][RW ][ ]
size_t lens[] = { 4, 1, 2, 2}; // in pages
size_t offsets[] = { 0, 0, 2, 1}; // in pages
for (int i = 0; i < ARRAY_SIZE(lens); i++) {
int flags = MAP_PRIVATE | MAP_ANONYMOUS;
void *buf = mmap(NULL, total_len * PAGE_SIZE, init_prot, flags, -1, 0);
if (buf == MAP_FAILED) {
THROW_ERROR("mmap failed");
}
size_t len = lens[i] * PAGE_SIZE;
size_t offset = offsets[i] * PAGE_SIZE;
int prot = PROT_READ | PROT_WRITE;
void *tmp_buf = (char *)buf + offset;
int ret = mprotect(tmp_buf, len, prot);
if (ret < 0) {
THROW_ERROR("mprotect failed");
}
ret = munmap(buf, total_len * PAGE_SIZE);
if (ret < 0) {
THROW_ERROR("munmap failed");
}
}
return 0;
}
int test_mprotect_twice() {
// The memory permissions initially looks like below:
//
// Pages: #0 #1 #2 #3
// -------------------------------------
// Memory perms: [ ][ ][ ][ ]
size_t total_len = 4; // in pages
int init_prot = PROT_NONE;
// The four settings for mprotects and their results
//
// Pages: #0 #1 #2 #3
// -------------------------------------
// Setting (i = 0):
// mprotect (j = 0): [RW ][RW ]
// mprotect (j = 1): [RW ][RW ]
// result: [RW ][RW ][RW ][RW ]
// Setting (i = 1):
// mprotect (j = 0): [RW ]
// mprotect (j = 1): [RW ]
// result: [ ][RW ][ ][RW ]
// Setting (i = 2):
// mprotect (j = 0): [RW ][RW ]
// mprotect (j = 1): [ WX][ WX]
// result: [ ][ WX][ WX][ ]
// Setting (i = 3):
// mprotect (j = 0): [RW ][RW ]
// mprotect (j = 1): [ ]
// result: [ ][ ][RW ][ ]
size_t lens[][2] = {
{ 2, 2 },
{ 1, 1 },
{ 2, 2 },
{ 2, 1 }
}; // in pages
size_t offsets[][2] = {
{ 0, 2 },
{ 1, 3 },
{ 1, 1 },
{ 1, 1 }
}; // in pages
int prots[][2] = {
{ PROT_READ | PROT_WRITE, PROT_READ | PROT_WRITE },
{ PROT_READ | PROT_WRITE, PROT_READ | PROT_WRITE },
{ PROT_READ | PROT_WRITE, PROT_WRITE | PROT_EXEC },
{ PROT_READ | PROT_WRITE, PROT_NONE }
};
for (int i = 0; i < ARRAY_SIZE(lens); i++) {
int flags = MAP_PRIVATE | MAP_ANONYMOUS;
void *buf = mmap(NULL, total_len * PAGE_SIZE, init_prot, flags, -1, 0);
if (buf == MAP_FAILED) {
THROW_ERROR("mmap failed");
}
for (int j = 0; j < 2; j++) {
size_t len = lens[i][j] * PAGE_SIZE;
size_t offset = offsets[i][j] * PAGE_SIZE;
int prot = prots[i][j];
void *tmp_buf = (char *)buf + offset;
int ret = mprotect(tmp_buf, len, prot);
if (ret < 0) {
THROW_ERROR("mprotect failed");
}
}
int ret = munmap(buf, total_len * PAGE_SIZE);
if (ret < 0) {
THROW_ERROR("munmap failed");
}
}
return 0;
}
int test_mprotect_triple() {
// The memory permissions initially looks like below:
//
// Pages: #0 #1 #2 #3
// -------------------------------------
// Memory perms: [RWX][RWX][RWX][RWX]
size_t total_len = 4; // in pages
int init_prot = PROT_READ | PROT_WRITE | PROT_EXEC;
// The four settings for mprotects and their results
//
// Pages: #0 #1 #2 #3
// -------------------------------------
// Setting (i = 0):
// mprotect (j = 0): [ ][ ]
// mprotect (j = 1): [ ]
// mprotect (j = 2): [ ]
// result: [ ][ ][ ][ ]
size_t lens[][3] = {
{ 2, 1, 1 },
}; // in pages
size_t offsets[][3] = {
{ 0, 3, 2 },
}; // in pages
int prots[][3] = {
{ PROT_NONE, PROT_NONE, PROT_NONE },
};
for (int i = 0; i < ARRAY_SIZE(lens); i++) {
int flags = MAP_PRIVATE | MAP_ANONYMOUS;
void *buf = mmap(NULL, total_len * PAGE_SIZE, init_prot, flags, -1, 0);
if (buf == MAP_FAILED) {
THROW_ERROR("mmap failed");
}
for (int j = 0; j < 3; j++) {
size_t len = lens[i][j] * PAGE_SIZE;
size_t offset = offsets[i][j] * PAGE_SIZE;
int prot = prots[i][j];
void *tmp_buf = (char *)buf + offset;
int ret = mprotect(tmp_buf, len, prot);
if (ret < 0) {
THROW_ERROR("mprotect failed");
}
}
int ret = munmap(buf, total_len * PAGE_SIZE);
if (ret < 0) {
THROW_ERROR("munmap failed");
}
}
return 0;
}
int test_mprotect_with_zero_len() {
int flags = MAP_PRIVATE | MAP_ANONYMOUS;
void *buf = mmap(NULL, PAGE_SIZE, PROT_NONE, flags, -1, 0);
if (buf == MAP_FAILED) {
THROW_ERROR("mmap failed");
}
int ret = mprotect(buf, 0, PROT_NONE);
if (ret < 0) {
THROW_ERROR("mprotect failed");
}
ret = munmap(buf, PAGE_SIZE);
if (ret < 0) {
THROW_ERROR("munmap failed");
}
return 0;
}
int test_mprotect_with_invalid_addr() {
int ret = mprotect(NULL, PAGE_SIZE, PROT_NONE);
if (ret == 0 || errno != ENOMEM) {
THROW_ERROR("using invalid addr should have failed");
}
return 0;
}
int test_mprotect_with_invalid_prot() {
int invalid_prot = 0x1234; // invalid protection bits
void *valid_addr = &invalid_prot;
size_t valid_len = PAGE_SIZE;
int ret = mprotect(valid_addr, valid_len, invalid_prot);
if (ret == 0 || errno != EINVAL) {
THROW_ERROR("using invalid addr should have failed");
}
return 0;
}
// ============================================================================ // ============================================================================
// Test suite main // Test suite main
// ============================================================================ // ============================================================================
@ -762,6 +983,12 @@ static test_case_t test_cases[] = {
TEST_CASE(test_mremap), TEST_CASE(test_mremap),
TEST_CASE(test_mremap_subrange), TEST_CASE(test_mremap_subrange),
TEST_CASE(test_mremap_with_fixed_addr), TEST_CASE(test_mremap_with_fixed_addr),
TEST_CASE(test_mprotect_once),
TEST_CASE(test_mprotect_twice),
TEST_CASE(test_mprotect_triple),
TEST_CASE(test_mprotect_with_zero_len),
TEST_CASE(test_mprotect_with_invalid_addr),
TEST_CASE(test_mprotect_with_invalid_prot),
}; };
int main() { int main() {