Optimize ProcessVM to use interior mutability

He Sun 2020-06-17 18:23:46 +08:00
parent bca0663972
commit f854950416
7 changed files with 41 additions and 59 deletions
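
In short: the process-wide `Arc<SgxMutex<ProcessVM>>` becomes a plain `Arc<ProcessVM>`, and the locking moves inside `ProcessVM` — a `SgxMutex` around the one genuinely mutable structure (`VMManager`) and an `AtomicUsize` for the program break — so every public method can take `&self`. A minimal sketch of the pattern (plain `std::sync::Mutex` standing in for `SgxMutex`; all names illustrative):

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};

struct MmapManager; // stand-in for the real VMManager

// After the change: callers share ProcessVm with a bare Arc, and the
// synchronization is an implementation detail of the type itself.
struct ProcessVm {
    mmap_manager: Mutex<MmapManager>, // locked only where mutation happens
    brk: AtomicUsize,                 // single word, updated lock-free
}

impl ProcessVm {
    fn brk(&self, new_brk: usize) -> usize {
        // &self suffices: interior mutability hides the write
        self.brk.store(new_brk, Ordering::SeqCst);
        new_brk
    }
}

fn main() {
    let vm = Arc::new(ProcessVm {
        mmap_manager: Mutex::new(MmapManager),
        brk: AtomicUsize::new(0x1000),
    });
    let vm2 = Arc::clone(&vm);
    assert_eq!(vm2.brk(0x2000), 0x2000); // no outer .lock().unwrap()
}
```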

@@ -34,7 +34,6 @@ pub fn do_clone(
     let current = current!();
     let vm = current.vm().clone();
     let task = {
-        let vm = vm.lock().unwrap();
         let user_stack_range = guess_user_stack_bound(&vm, user_rsp)?;
         let user_stack_base = user_stack_range.end();
         let user_stack_limit = user_stack_range.start();
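
With the outer mutex gone, `&vm` deref-coerces from the `Arc` straight to `&ProcessVM`, so no `MutexGuard` has to stay alive across `guess_user_stack_bound`. A tiny illustration of the coercion (hypothetical types):

```rust
use std::sync::Arc;

struct ProcessVm;
impl ProcessVm {
    fn stack_top(&self) -> usize { 0x7000_0000 }
}

fn guess_bound(vm: &ProcessVm) -> usize { vm.stack_top() }

fn main() {
    let vm = Arc::new(ProcessVm);
    // &vm is &Arc<ProcessVm>; deref coercion turns it into &ProcessVm.
    assert_eq!(guess_bound(&vm), 0x7000_0000);
}
```
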
@@ -230,16 +229,16 @@ fn check_clone_flags(flags: CloneFlags) -> Result<()> {
     Ok(())
 }
 
-fn guess_user_stack_bound(vm: &ProcessVM, user_rsp: usize) -> Result<&VMRange> {
+fn guess_user_stack_bound(vm: &ProcessVM, user_rsp: usize) -> Result<VMRange> {
     // The first case is most likely
     if let Ok(stack_range) = vm.find_mmap_region(user_rsp) {
         Ok(stack_range)
     }
     // The next three cases are very unlikely, but valid
     else if vm.get_stack_range().contains(user_rsp) {
-        Ok(vm.get_stack_range())
+        Ok(*vm.get_stack_range())
     } else if vm.get_heap_range().contains(user_rsp) {
-        Ok(vm.get_heap_range())
+        Ok(*vm.get_heap_range())
     }
     // Invalid
     else {
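
The return type changes from `Result<&VMRange>` to `Result<VMRange>` out of necessity: with the manager behind an internal mutex, a returned reference would borrow from a `MutexGuard` that dies inside the callee. Since `VMRange` is a small `Copy` type (the `*vm.get_stack_range()` deref above relies on that), handing out a copy is the idiomatic fix. A sketch with stand-in types:

```rust
use std::sync::Mutex;

#[derive(Clone, Copy, Debug)]
struct Range { start: usize, end: usize }

struct Vm { regions: Mutex<Vec<Range>> }

impl Vm {
    // Returning &Range would borrow from the guard, which is dropped
    // when this function returns, so we copy the value out instead.
    fn find_region(&self, addr: usize) -> Option<Range> {
        let regions = self.regions.lock().unwrap();
        regions.iter().copied().find(|r| r.start <= addr && addr < r.end)
    }
}

fn main() {
    let vm = Vm { regions: Mutex::new(vec![Range { start: 0, end: 0x1000 }]) };
    println!("{:?}", vm.find_region(0x42));
}
```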

@@ -156,7 +156,7 @@ fn new_process(
             )?
         }
     };
-    let vm_ref = Arc::new(SgxMutex::new(vm));
+    let vm_ref = Arc::new(vm);
    let files_ref = {
        let files = init_files(current_ref, file_actions, host_stdio_fds)?;
        Arc::new(SgxMutex::new(files))

@@ -59,7 +59,7 @@ pub type uid_t = u32;
 pub type ProcessRef = Arc<Process>;
 pub type ThreadRef = Arc<Thread>;
 pub type FileTableRef = Arc<SgxMutex<FileTable>>;
-pub type ProcessVMRef = Arc<SgxMutex<ProcessVM>>;
+pub type ProcessVMRef = Arc<ProcessVM>;
 pub type FsViewRef = Arc<SgxMutex<FsView>>;
 pub type SchedAgentRef = Arc<SgxMutex<SchedAgent>>;
 pub type ResourceLimitsRef = Arc<SgxMutex<ResourceLimits>>;
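
Only `ProcessVMRef` changes shape; the neighboring aliases keep the outer-mutex form. The two shapes side by side (std `Mutex` standing in for `SgxMutex`):

```rust
use std::sync::{Arc, Mutex};

struct FileTable;
struct ProcessVm; // exposes only &self methods; locks internally

// Outer mutex: every user must lock before touching FileTable.
type FileTableRef = Arc<Mutex<FileTable>>;
// Interior mutability: cloning the Arc is all that sharing requires.
type ProcessVmRef = Arc<ProcessVm>;

fn main() {
    let _files: FileTableRef = Arc::new(Mutex::new(FileTable));
    let _vm: ProcessVmRef = Arc::new(ProcessVm);
}
```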

@@ -16,7 +16,7 @@ lazy_static! {
 fn create_idle_thread() -> Result<ThreadRef> {
     // Create dummy values for the mandatory fields
     let dummy_tid = ThreadId::zero();
-    let dummy_vm = Arc::new(SgxMutex::new(ProcessVM::default()));
+    let dummy_vm = Arc::new(ProcessVM::default());
     let dummy_task = Task::default();
 
     // Assemble the idle process

@@ -87,8 +87,7 @@ pub mod from_user {
    /// len: the length in byte
    fn is_inside_user_space(addr: *const u8, len: usize) -> bool {
        let current = current!();
-        let current_vm = current.vm().lock().unwrap();
-        let user_range = current_vm.get_process_range();
+        let user_range = current.vm().get_process_range();
        let ur_start = user_range.start();
        let ur_end = user_range.end();
        let addr_start = addr as usize;
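
`get_process_range()` only reads a field that never changes after the VM is built, so the user-pointer check needs no lock at all. A self-contained version of the idea (simplified, with an overflow-safe end check):

```rust
struct Range { start: usize, end: usize }

struct Vm { process_range: Range } // fixed at construction time

impl Vm {
    fn get_process_range(&self) -> &Range { &self.process_range }
}

fn is_inside_user_space(vm: &Vm, addr: usize, len: usize) -> bool {
    let r = vm.get_process_range();
    // checked_add guards against addr + len wrapping around.
    addr >= r.start && addr.checked_add(len).map_or(false, |end| end <= r.end)
}

fn main() {
    let vm = Vm { process_range: Range { start: 0x1000, end: 0x9000 } };
    assert!(is_inside_user_space(&vm, 0x2000, 0x100));
    assert!(!is_inside_user_space(&vm, usize::MAX, 2));
}
```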

@@ -39,16 +39,13 @@ pub fn do_mmap(
         );
     }
 
-    let current = current!();
-    let mut current_vm = current.vm().lock().unwrap();
-    current_vm.mmap(addr, size, perms, flags, fd, offset)
+    current!().vm().mmap(addr, size, perms, flags, fd, offset)
 }
 
 pub fn do_munmap(addr: usize, size: usize) -> Result<()> {
     debug!("munmap: addr: {:#x}, size: {:#x}", addr, size);
     let current = current!();
-    let mut current_vm = current.vm().lock().unwrap();
-    current_vm.munmap(addr, size)
+    current!().vm().munmap(addr, size)
 }
 
 pub fn do_mremap(
@@ -61,9 +58,7 @@ pub fn do_mremap(
         "mremap: old_addr: {:#x}, old_size: {:#x}, new_size: {:#x}, flags: {:?}",
         old_addr, old_size, new_size, flags
     );
-    let current = current!();
-    let mut current_vm = current.vm().lock().unwrap();
-    current_vm.mremap(old_addr, old_size, new_size, flags)
+    current!().vm().mremap(old_addr, old_size, new_size, flags)
 }
 
 pub fn do_mprotect(addr: usize, size: usize, perms: VMPerms) -> Result<()> {
@@ -71,16 +66,12 @@ pub fn do_mprotect(addr: usize, size: usize, perms: VMPerms) -> Result<()> {
         "mprotect: addr: {:#x}, size: {:#x}, perms: {:?}",
         addr, size, perms
     );
-    let current = current!();
-    let mut current_vm = current.vm().lock().unwrap();
-    current_vm.mprotect(addr, size, perms)
+    current!().vm().mprotect(addr, size, perms)
 }
 
 pub fn do_brk(addr: usize) -> Result<usize> {
     debug!("brk: addr: {:#x}", addr);
-    let current = current!();
-    let mut current_vm = current.vm().lock().unwrap();
-    current_vm.brk(addr)
+    current!().vm().brk(addr)
 }
 
 pub const PAGE_SIZE: usize = 4096;
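
Each syscall wrapper collapses to a one-line delegation, and the critical section shrinks from the whole wrapper body to a single statement inside `ProcessVM`. Roughly this shape (illustrative names):

```rust
use std::sync::Mutex;

struct Vm { mapped_bytes: Mutex<usize> }

impl Vm {
    fn mmap(&self, size: usize) -> usize {
        // The guard lives only for this block, not for the whole syscall.
        let mut mapped = self.mapped_bytes.lock().unwrap();
        *mapped += size;
        *mapped
    }
}

fn do_mmap(vm: &Vm, size: usize) -> usize {
    vm.mmap(size) // caller-side: no lock, just delegate
}

fn main() {
    let vm = Vm { mapped_bytes: Mutex::new(0) };
    assert_eq!(do_mmap(&vm, 4096), 4096);
}
```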

@@ -7,6 +7,7 @@ use super::vm_manager::{
     VMInitializer, VMManager, VMMapAddr, VMMapOptions, VMMapOptionsBuilder, VMRemapOptions,
 };
 use super::vm_perms::VMPerms;
+use std::sync::atomic::{AtomicUsize, Ordering};
 
 #[derive(Debug)]
 pub struct ProcessVMBuilder<'a, 'b> {
@@ -118,13 +119,7 @@ impl<'a, 'b> ProcessVMBuilder<'a, 'b> {
             last_elf_range.end()
         };
         let heap_range = VMRange::new_with_layout(heap_layout, heap_min_start);
-        unsafe {
-            let heap_buf = heap_range.as_slice_mut();
-            for b in heap_buf {
-                *b = 0;
-            }
-        }
-        let brk = heap_range.start();
+        let brk = AtomicUsize::new(heap_range.start());
 
         // Init the stack memory in the process
         let stack_layout = &other_layouts[1];
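
Two things happen in this hunk: the byte-by-byte zeroing of the freshly created heap is dropped, and `brk` is seeded as an `AtomicUsize` at the heap's start. For reference only (not part of the commit), zeroing a raw byte range is normally done with `ptr::write_bytes` rather than a `for` loop; shown here on an owned buffer:

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

fn main() {
    // brk starts at the bottom of the heap and is atomic from birth,
    // so later adjustments only ever need &self.
    let heap_start = 0x4000_usize;
    let brk = AtomicUsize::new(heap_start);
    assert_eq!(brk.load(Ordering::SeqCst), heap_start);

    // Bulk zeroing idiom, demonstrated on an owned Vec rather than a
    // raw enclave range:
    let mut buf = vec![0xffu8; 32];
    unsafe { std::ptr::write_bytes(buf.as_mut_ptr(), 0u8, buf.len()) };
    assert!(buf.iter().all(|&b| b == 0));
}
```
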
@@ -147,6 +142,8 @@ impl<'a, 'b> ProcessVMBuilder<'a, 'b> {
         debug_assert!(process_range.range().is_superset_of(&stack_range));
         debug_assert!(process_range.range().is_superset_of(&mmap_range));
 
+        let mmap_manager = SgxMutex::new(mmap_manager);
+
         Ok(ProcessVM {
             process_range,
             elf_ranges,
@@ -203,11 +200,11 @@ impl<'a, 'b> ProcessVMBuilder<'a, 'b> {
 /// The per-process virtual memory
 #[derive(Debug)]
 pub struct ProcessVM {
-    mmap_manager: VMManager,
+    mmap_manager: SgxMutex<VMManager>,
     elf_ranges: Vec<VMRange>,
     heap_range: VMRange,
     stack_range: VMRange,
-    brk: usize,
+    brk: AtomicUsize,
     // Memory safety notes: the process_range field must be the last one.
     //
     // Rust drops fields in the same order as they are declared. So by making
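
The struct now uses a different sharing strategy per field: plain immutable fields (`elf_ranges`, `heap_range`, `stack_range`) read through `&self`, an atomic for the one mutable word, and a mutex only around `VMManager`. The drop-order comment is load-bearing, and easy to verify: Rust drops fields in declaration order, as this sketch shows:

```rust
struct Noisy(&'static str);

impl Drop for Noisy {
    fn drop(&mut self) {
        println!("dropping {}", self.0);
    }
}

struct S {
    first: Noisy,
    last: Noisy, // dropped after `first`, like process_range here
}

fn main() {
    // Prints "dropping first" then "dropping last": declaration order,
    // which is why process_range must remain the final field.
    let _s = S { first: Noisy("first"), last: Noisy("last") };
}
```
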
@@ -260,10 +257,10 @@ impl ProcessVM {
     }
 
     pub fn get_brk(&self) -> usize {
-        self.brk
+        self.brk.load(Ordering::SeqCst)
     }
 
-    pub fn brk(&mut self, new_brk: usize) -> Result<usize> {
+    pub fn brk(&self, new_brk: usize) -> Result<usize> {
         let heap_start = self.heap_range.start();
         let heap_end = self.heap_range.end();
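
`get_brk` becomes an atomic load. `SeqCst` is the conservative choice; for a single independent word a weaker ordering would likely also do, but matching the store's ordering keeps the reasoning simple. A compilable miniature:

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

struct Vm { brk: AtomicUsize }

impl Vm {
    fn get_brk(&self) -> usize {
        self.brk.load(Ordering::SeqCst) // pairs with the SeqCst store in brk()
    }
}

fn main() {
    let vm = Vm { brk: AtomicUsize::new(0x8000) };
    assert_eq!(vm.get_brk(), 0x8000);
}
```
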
@@ -275,16 +272,13 @@ impl ProcessVM {
             return_errno!(EINVAL, "New brk address is too high");
         }
 
-        if self.brk < new_brk {
-            unsafe { fill_zeros(self.brk, new_brk - self.brk) };
-        }
-        self.brk = new_brk;
-        return Ok(new_brk);
+        self.brk
+            .fetch_update(|old_brk| Some(new_brk), Ordering::SeqCst, Ordering::SeqCst);
+        Ok(new_brk)
     }
 
     pub fn mmap(
-        &mut self,
+        &self,
         addr: usize,
         size: usize,
         perms: VMPerms,
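
A note on the new `brk` body: the closure ignores `old_brk` and always returns `Some(new_brk)`, so the `fetch_update` can never fail and behaves like an unconditional store. Also, stabilized `std` (Rust 1.45+) orders the arguments `fetch_update(set_order, fetch_order, closure)`; the closure-first form above matches the SGX SDK toolchain of the time. Both points in a sketch against current `std`:

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

fn main() {
    let brk = AtomicUsize::new(0x8000);
    let new_brk = 0x9000;

    // Stable signature: orderings first, closure last. Ignoring the old
    // value means the update is unconditional and cannot return Err.
    let _ = brk.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |_old| Some(new_brk));
    assert_eq!(brk.load(Ordering::SeqCst), new_brk);

    // Equivalent, simpler spelling when the old value is irrelevant:
    brk.store(new_brk, Ordering::SeqCst);
}
```
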
@@ -323,12 +317,12 @@ impl ProcessVM {
             .perms(perms)
             .initializer(initializer)
             .build()?;
-        let mmap_addr = self.mmap_manager.mmap(&mmap_options)?;
+        let mmap_addr = self.mmap_manager.lock().unwrap().mmap(&mmap_options)?;
         Ok(mmap_addr)
     }
 
     pub fn mremap(
-        &mut self,
+        &self,
         old_addr: usize,
         old_size: usize,
         new_size: usize,
@@ -341,29 +335,36 @@ impl ProcessVM {
         }
 
         let mremap_option = VMRemapOptions::new(old_addr, old_size, new_size, flags)?;
-        self.mmap_manager.mremap(&mremap_option)
+        self.mmap_manager.lock().unwrap().mremap(&mremap_option)
     }
 
-    pub fn munmap(&mut self, addr: usize, size: usize) -> Result<()> {
-        self.mmap_manager.munmap(addr, size)
+    pub fn munmap(&self, addr: usize, size: usize) -> Result<()> {
+        self.mmap_manager.lock().unwrap().munmap(addr, size)
     }
 
-    pub fn mprotect(&mut self, addr: usize, size: usize, perms: VMPerms) -> Result<()> {
+    pub fn mprotect(&self, addr: usize, size: usize, perms: VMPerms) -> Result<()> {
         let protect_range = VMRange::new_with_size(addr, size)?;
         if !self.process_range.range().is_superset_of(&protect_range) {
             return_errno!(ENOMEM, "invalid range");
         }
+        let mut mmap_manager = self.mmap_manager.lock().unwrap();
         // TODO: support mprotect vm regions in addition to mmap
-        if !self.mmap_manager.range().is_superset_of(&protect_range) {
+        if !mmap_manager.range().is_superset_of(&protect_range) {
             warn!("Do not support mprotect memory outside the mmap region yet");
             return Ok(());
         }
-        self.mmap_manager.mprotect(addr, size, perms)
+        mmap_manager.mprotect(addr, size, perms)
     }
 
-    pub fn find_mmap_region(&self, addr: usize) -> Result<&VMRange> {
-        self.mmap_manager.find_mmap_region(addr)
+    // Return: a copy of the found region
+    pub fn find_mmap_region(&self, addr: usize) -> Result<VMRange> {
+        self.mmap_manager
+            .lock()
+            .unwrap()
+            .find_mmap_region(addr)
+            .map(|range_ref| *range_ref)
     }
 }
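
`.map(|range_ref| *range_ref)` compiles only because `VMRange` is `Copy`; a non-`Copy` type would need `.map(|r| r.clone())`. The same pattern in isolation:

```rust
#[derive(Clone, Copy, Debug, PartialEq)]
struct Range { start: usize, end: usize }

fn find(ranges: &[Range], addr: usize) -> Result<&Range, ()> {
    ranges.iter().find(|r| r.start <= addr && addr < r.end).ok_or(())
}

fn main() {
    let ranges = [Range { start: 0, end: 0x1000 }];
    // Dereferencing copies the range out, detaching the result from
    // the borrow (and, in the real code, from the MutexGuard).
    let owned: Result<Range, ()> = find(&ranges, 0x42).map(|range_ref| *range_ref);
    assert_eq!(owned, Ok(Range { start: 0, end: 0x1000 }));
}
```
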
@@ -433,11 +434,3 @@ impl Default for MRemapFlags {
         MRemapFlags::None
     }
 }
-
-unsafe fn fill_zeros(addr: usize, size: usize) {
-    let ptr = addr as *mut u8;
-    let buf = std::slice::from_raw_parts_mut(ptr, size);
-    for b in buf {
-        *b = 0;
-    }
-}