From e352a190ea63cc9370371026e9bf620f92b24b41 Mon Sep 17 00:00:00 2001 From: He Sun Date: Wed, 15 Jan 2020 17:42:58 +0800 Subject: [PATCH] Optimize the perf of sendmsg/recvmsg by allocating untrusted buffers directly It is slow to allocate big buffers using SGX SDK's malloc. Even worse, it consumes a large amount of precious trusted memory inside enclaves. This commit avoids using trusted buffers and allocates untrusted buffers for sendmsg/recvmsg directly via OCall, thus improving the performance of sendmsg/recvmsg. Note that this optimization does not affect the security of network data as it has to be sent/received via OCalls. --- src/Enclave.edl | 11 ++-- src/libos/src/error/to_errno.rs | 12 ++++ src/libos/src/lib.rs | 3 + src/libos/src/net/iovs.rs | 74 +++++++++++------------ src/libos/src/net/mod.rs | 3 +- src/libos/src/net/socket_file/recv.rs | 62 ++++++++++++-------- src/libos/src/net/socket_file/send.rs | 58 ++++++++++-------- src/libos/src/prelude.rs | 26 --------- src/libos/src/untrusted/alloc.rs | 62 ++++++++++++++++++++ src/libos/src/untrusted/mod.rs | 10 ++++ src/libos/src/untrusted/slice_alloc.rs | 81 ++++++++++++++++++++++++++ src/libos/src/untrusted/slice_ext.rs | 76 ++++++++++++++++++++++++ src/pal/include/edl/occlum_edl_types.h | 1 + src/pal/src/ocalls/mem.c | 27 +++++++++ src/pal/src/ocalls/net.c | 20 ++----- test/client/main.c | 7 +++ test/server/main.c | 33 +++++++---- 17 files changed, 426 insertions(+), 140 deletions(-) create mode 100644 src/libos/src/untrusted/alloc.rs create mode 100644 src/libos/src/untrusted/mod.rs create mode 100644 src/libos/src/untrusted/slice_alloc.rs create mode 100644 src/libos/src/untrusted/slice_ext.rs create mode 100644 src/pal/src/ocalls/mem.c diff --git a/src/Enclave.edl b/src/Enclave.edl index 2acc43ce..faa833ef 100644 --- a/src/Enclave.edl +++ b/src/Enclave.edl @@ -57,6 +57,9 @@ enclave { void occlum_ocall_sync(void); + void* occlum_ocall_posix_memalign(size_t alignment, size_t size); + void occlum_ocall_free([user_check] void* ptr); + void occlum_ocall_sched_yield(void); int occlum_ocall_sched_getaffinity( int host_tid, @@ -87,8 +90,8 @@ enclave { int sockfd, [in, size=msg_namelen] const void* msg_name, socklen_t msg_namelen, - [in, size=buf_len] const void* buf, - size_t buf_len, + [in, count=msg_iovlen] const struct iovec* msg_iov, + size_t msg_iovlen, [in, size=msg_controllen] const void* msg_control, size_t msg_controllen, int flags @@ -98,8 +101,8 @@ enclave { [out, size=msg_namelen] void *msg_name, socklen_t msg_namelen, [out] socklen_t* msg_namelen_recv, - [out, size=buf_len] void* buf, - size_t buf_len, + [in, count=msg_iovlen] struct iovec* msg_iov, + size_t msg_iovlen, [out, size=msg_controllen] void *msg_control, size_t msg_controllen, [out] size_t* msg_controllen_recv, diff --git a/src/libos/src/error/to_errno.rs b/src/libos/src/error/to_errno.rs index 90735ae2..5735459e 100644 --- a/src/libos/src/error/to_errno.rs +++ b/src/libos/src/error/to_errno.rs @@ -98,3 +98,15 @@ impl ToErrno for rcore_fs::vfs::FsError { } } } + +impl ToErrno for std::alloc::AllocErr { + fn errno(&self) -> Errno { + ENOMEM + } +} + +impl ToErrno for std::alloc::LayoutErr { + fn errno(&self) -> Errno { + EINVAL + } +} diff --git a/src/libos/src/lib.rs b/src/libos/src/lib.rs index b3fc4be8..584dce93 100644 --- a/src/libos/src/lib.rs +++ b/src/libos/src/lib.rs @@ -7,6 +7,8 @@ #![feature(core_intrinsics)] #![feature(stmt_expr_attributes)] #![feature(atomic_min_max)] +#![feature(no_more_cas)] +#![feature(alloc_layout_extra)] #[macro_use] extern crate alloc; @@ -59,6 +61,7 @@ mod net; mod process; mod syscall; mod time; +mod untrusted; mod util; mod vm; diff --git a/src/libos/src/net/iovs.rs b/src/libos/src/net/iovs.rs index 42b60d08..fb11615e 100644 --- a/src/libos/src/net/iovs.rs +++ b/src/libos/src/net/iovs.rs @@ -1,6 +1,8 @@ //! I/O vectors use super::*; +use crate::untrusted::SliceAsPtrAndLen; +use std::iter::Iterator; /// A memory safe, immutable version of C iovec array pub struct Iovs<'a> { @@ -19,19 +21,6 @@ impl<'a> Iovs<'a> { pub fn total_bytes(&self) -> usize { self.iovs.iter().map(|s| s.len()).sum() } - - pub fn gather_to_vec(&self) -> Vec { - Self::gather_slices_to_vec(&self.iovs[..]) - } - - fn gather_slices_to_vec(slices: &[&[u8]]) -> Vec { - let vec_len = slices.iter().map(|slice| slice.len()).sum(); - let mut vec = Vec::with_capacity(vec_len); - for slice in slices { - vec.extend_from_slice(slice); - } - vec - } } /// A memory safe, mutable version of C iovec array @@ -59,30 +48,41 @@ impl<'a> IovsMut<'a> { self.iovs.iter().map(|s| s.len()).sum() } - pub fn gather_to_vec(&self) -> Vec { - Iovs::gather_slices_to_vec(self.as_slices()) - } - - pub fn scatter_copy_from(&mut self, data: &[u8]) -> usize { - let mut total_nbytes = 0; - let mut remain_slice = data; - for iov in &mut self.iovs { - if remain_slice.len() == 0 { - break; - } - - let copy_nbytes = remain_slice.len().min(iov.len()); - let dst_slice = unsafe { - debug_assert!(iov.len() >= copy_nbytes); - iov.get_unchecked_mut(..copy_nbytes) - }; - let (src_slice, _remain_slice) = remain_slice.split_at(copy_nbytes); - dst_slice.copy_from_slice(src_slice); - - remain_slice = _remain_slice; - total_nbytes += copy_nbytes; + /// Copy as many bytes from an u8 iterator as possible + pub fn copy_from_iter<'b, T>(&mut self, src_iter: &mut T) -> usize + where + T: Iterator, + { + let mut bytes_copied = 0; + let mut dst_iter = self + .as_slices_mut() + .iter_mut() + .flat_map(|mut slice| slice.iter_mut()); + while let (Some(mut d), Some(s)) = (dst_iter.next(), src_iter.next()) { + *d = *s; + bytes_copied += 1; } - debug_assert!(remain_slice.len() == 0); - total_nbytes + bytes_copied + } +} + +/// An extention trait that converts slice to libc::iovec +pub trait SliceAsLibcIovec { + fn as_libc_iovec(&self) -> libc::iovec; +} + +impl SliceAsLibcIovec for &[u8] { + fn as_libc_iovec(&self) -> libc::iovec { + let (iov_base, iov_len) = self.as_ptr_and_len(); + let iov_base = iov_base as *mut u8 as *mut c_void; + libc::iovec { iov_base, iov_len } + } +} + +impl SliceAsLibcIovec for &mut [u8] { + fn as_libc_iovec(&self) -> libc::iovec { + let (iov_base, iov_len) = self.as_ptr_and_len(); + let iov_base = iov_base as *mut u8 as *mut c_void; + libc::iovec { iov_base, iov_len } } } diff --git a/src/libos/src/net/mod.rs b/src/libos/src/net/mod.rs index e1283731..d0ada01f 100644 --- a/src/libos/src/net/mod.rs +++ b/src/libos/src/net/mod.rs @@ -1,4 +1,5 @@ use super::*; +use std::*; mod iovs; mod msg; @@ -6,7 +7,7 @@ mod msg_flags; mod socket_file; mod syscalls; -pub use self::iovs::{Iovs, IovsMut}; +pub use self::iovs::{Iovs, IovsMut, SliceAsLibcIovec}; pub use self::msg::{msghdr, msghdr_mut, MsgHdr, MsgHdrMut}; pub use self::msg_flags::MsgFlags; pub use self::socket_file::{AsSocket, SocketFile}; diff --git a/src/libos/src/net/socket_file/recv.rs b/src/libos/src/net/socket_file/recv.rs index 064f7f2b..7f688ab1 100644 --- a/src/libos/src/net/socket_file/recv.rs +++ b/src/libos/src/net/socket_file/recv.rs @@ -1,4 +1,5 @@ use super::*; +use crate::untrusted::{SliceAsMutPtrAndLen, SliceAsPtrAndLen, UntrustedSliceAlloc}; impl SocketFile { // TODO: need sockaddr type to implement send/sento @@ -19,37 +20,48 @@ impl SocketFile { }*/ pub fn recvmsg<'a, 'b>(&self, msg: &'b mut MsgHdrMut<'a>, flags: MsgFlags) -> Result { - // Allocate a single data buffer is big enough for all iovecs of msg. - // This is a workaround for the OCall that takes only one data buffer. - let mut data_buf = { - let data_buf_len = msg.get_iovs().total_bytes(); - let data_vec = vec![0; data_buf_len]; - data_vec.into_boxed_slice() - }; + // Alloc untrusted iovecs to receive data via OCall + let msg_iov = msg.get_iovs(); + let u_slice_alloc = UntrustedSliceAlloc::new(msg_iov.total_bytes())?; + let mut u_slices = msg_iov + .as_slices() + .iter() + .map(|slice| { + u_slice_alloc + .new_slice_mut(slice.len()) + .expect("unexpected out of memory error in UntrustedSliceAlloc") + }) + .collect(); + let mut u_iovs = IovsMut::new(u_slices); + // Do OCall-based recvmsg let (bytes_recvd, namelen_recvd, controllen_recvd, flags_recvd) = { - let data = &mut data_buf[..]; // Acquire mutable references to the name and control buffers let (name, control) = msg.get_name_and_control_mut(); // Fill the data, the name, and the control buffers - self.do_recvmsg(data, flags, name, control)? + self.do_recvmsg(u_iovs.as_slices_mut(), flags, name, control)? }; - // Update the lengths and flags + // Update the output lengths and flags msg.set_name_len(namelen_recvd)?; msg.set_control_len(controllen_recvd)?; msg.set_flags(flags_recvd); - let recv_data = &data_buf[..bytes_recvd]; - // TODO: avoid this one extra copy due to the intermediate data buffer - msg.get_iovs_mut().scatter_copy_from(recv_data); + // Copy data from untrusted iovecs into the output iovecs + let mut msg_iov = msg.get_iovs_mut(); + let mut u_iovs_iter = u_iovs + .as_slices() + .iter() + .flat_map(|slice| slice.iter()) + .take(bytes_recvd); + msg_iov.copy_from_iter(&mut u_iovs_iter); Ok(bytes_recvd) } fn do_recvmsg( &self, - data: &mut [u8], + data: &mut [&mut [u8]], flags: MsgFlags, mut name: Option<&mut [u8]>, mut control: Option<&mut [u8]>, @@ -58,14 +70,15 @@ impl SocketFile { // Host socket fd let host_fd = self.host_fd; // Name - let (msg_name, msg_namelen) = name.get_mut_ptr_and_len(); + let (msg_name, msg_namelen) = name.as_mut_ptr_and_len(); let msg_name = msg_name as *mut c_void; let mut msg_namelen_recvd = 0_u32; - // Data - let msg_data = data.as_mut_ptr(); - let msg_datalen = data.len(); + // Iovs + let mut raw_iovs: Vec = + data.iter().map(|slice| slice.as_libc_iovec()).collect(); + let (msg_iov, msg_iovlen) = raw_iovs.as_mut_slice().as_mut_ptr_and_len(); // Control - let (msg_control, msg_controllen) = control.get_mut_ptr_and_len(); + let (msg_control, msg_controllen) = control.as_mut_ptr_and_len(); let msg_control = msg_control as *mut c_void; let mut msg_controllen_recvd = 0; // Flags @@ -81,8 +94,8 @@ impl SocketFile { msg_name, msg_namelen as u32, &mut msg_namelen_recvd as *mut u32, - msg_data, - msg_datalen, + msg_iov, + msg_iovlen, msg_control, msg_controllen, &mut msg_controllen_recvd as *mut usize, @@ -103,7 +116,8 @@ impl SocketFile { let retval = retval as usize; // Check bytes_recvd returned from outside the enclave - assert!(retval <= data.len()); + let max_bytes_recvd = data.iter().map(|x| x.len()).sum(); + assert!(retval <= max_bytes_recvd); retval }; let msg_namelen_recvd = msg_namelen_recvd as usize; @@ -127,8 +141,8 @@ extern "C" { msg_name: *mut c_void, msg_namelen: libc::socklen_t, msg_namelen_recv: *mut libc::socklen_t, - msg_data: *mut u8, - msg_data: size_t, + msg_data: *mut libc::iovec, + msg_datalen: size_t, msg_control: *mut c_void, msg_controllen: size_t, msg_controllen_recv: *mut size_t, diff --git a/src/libos/src/net/socket_file/send.rs b/src/libos/src/net/socket_file/send.rs index dfd3acf7..635f8d46 100644 --- a/src/libos/src/net/socket_file/send.rs +++ b/src/libos/src/net/socket_file/send.rs @@ -1,4 +1,5 @@ use super::*; +use crate::untrusted::{SliceAsMutPtrAndLen, SliceAsPtrAndLen, UntrustedSliceAlloc}; impl SocketFile { // TODO: need sockaddr type to implement send/sento @@ -18,44 +19,55 @@ impl SocketFile { */ pub fn sendmsg<'a, 'b>(&self, msg: &'b MsgHdr<'a>, flags: MsgFlags) -> Result { - // Copy data in iovs into a single buffer - let data_buf = msg.get_iovs().gather_to_vec(); + // Copy message's iovecs into untrusted iovecs + let msg_iov = msg.get_iovs(); + let u_slice_alloc = UntrustedSliceAlloc::new(msg_iov.total_bytes())?; + let u_slices = msg_iov + .as_slices() + .iter() + .map(|src_slice| { + u_slice_alloc + .new_slice(src_slice) + .expect("unexpected out of memory") + }) + .collect(); + let u_iovs = Iovs::new(u_slices); - self.do_sendmsg(&data_buf[..], flags, msg.get_name(), msg.get_control()) + self.do_sendmsg(u_iovs.as_slices(), flags, msg.get_name(), msg.get_control()) } fn do_sendmsg( &self, - data: &[u8], + data: &[&[u8]], flags: MsgFlags, name: Option<&[u8]>, control: Option<&[u8]>, ) -> Result { - let bytes_sent = try_libc!({ - // Prepare the arguments for OCall - let mut retval: isize = 0; - // Host socket fd - let host_fd = self.host_fd; - // Name - let (msg_name, msg_namelen) = name.get_ptr_and_len(); - let msg_name = msg_name as *const c_void; - // Data - let msg_data = data.as_ptr(); - let msg_datalen = data.len(); - // Control - let (msg_control, msg_controllen) = control.get_ptr_and_len(); - let msg_control = msg_control as *const c_void; - // Flags - let flags = flags.to_u32() as i32; + // Prepare the arguments for OCall + let mut retval: isize = 0; + // Host socket fd + let host_fd = self.host_fd; + // Name + let (msg_name, msg_namelen) = name.as_ptr_and_len(); + let msg_name = msg_name as *const c_void; + // Iovs + let raw_iovs: Vec = data.iter().map(|slice| slice.as_libc_iovec()).collect(); + let (msg_iov, msg_iovlen) = raw_iovs.as_slice().as_ptr_and_len(); + // Control + let (msg_control, msg_controllen) = control.as_ptr_and_len(); + let msg_control = msg_control as *const c_void; + // Flags + let flags = flags.to_u32() as i32; + let bytes_sent = try_libc!({ // Do OCall let status = occlum_ocall_sendmsg( &mut retval as *mut isize, host_fd, msg_name, msg_namelen as u32, - msg_data, - msg_datalen, + msg_iov, + msg_iovlen, msg_control, msg_controllen, flags, @@ -75,7 +87,7 @@ extern "C" { fd: c_int, msg_name: *const c_void, msg_namelen: libc::socklen_t, - msg_data: *const u8, + msg_data: *const libc::iovec, msg_datalen: size_t, msg_control: *const c_void, msg_controllen: size_t, diff --git a/src/libos/src/prelude.rs b/src/libos/src/prelude.rs index 9205ff78..d3658527 100644 --- a/src/libos/src/prelude.rs +++ b/src/libos/src/prelude.rs @@ -31,29 +31,3 @@ pub fn align_down(addr: usize, align: usize) -> usize { pub fn unbox(value: Box) -> T { *value } - -pub trait SliceOptionExt { - fn get_ptr_and_len(&self) -> (*const T, usize); -} - -impl SliceOptionExt for Option<&[T]> { - fn get_ptr_and_len(&self) -> (*const T, usize) { - match self { - Some(self_slice) => (self_slice.as_ptr(), self_slice.len()), - None => (std::ptr::null(), 0), - } - } -} - -pub trait MutSliceOptionExt { - fn get_mut_ptr_and_len(&mut self) -> (*mut T, usize); -} - -impl MutSliceOptionExt for Option<&mut [T]> { - fn get_mut_ptr_and_len(&mut self) -> (*mut T, usize) { - match self { - Some(self_slice) => (self_slice.as_mut_ptr(), self_slice.len()), - None => (std::ptr::null_mut(), 0), - } - } -} diff --git a/src/libos/src/untrusted/alloc.rs b/src/libos/src/untrusted/alloc.rs new file mode 100644 index 00000000..1645e749 --- /dev/null +++ b/src/libos/src/untrusted/alloc.rs @@ -0,0 +1,62 @@ +use super::*; +use std::alloc::{Alloc, AllocErr, Layout}; +use std::ptr::{self, NonNull}; + +/// The global memory allocator for untrusted memory +pub static mut UNTRUSTED_ALLOC: UntrustedAlloc = UntrustedAlloc; + +pub struct UntrustedAlloc; + +unsafe impl Alloc for UntrustedAlloc { + unsafe fn alloc(&mut self, layout: Layout) -> std::result::Result, AllocErr> { + if layout.size() == 0 { + return Err(AllocErr); + } + + // Do OCall to allocate the untrusted memory according to the given layout + let layout = layout + .align_to(std::mem::size_of::<*const c_void>()) + .unwrap(); + let mem_ptr = { + let mut mem_ptr: *mut c_void = ptr::null_mut(); + let sgx_status = unsafe { + occlum_ocall_posix_memalign(&mut mem_ptr as *mut _, layout.align(), layout.size()) + }; + debug_assert!(sgx_status == sgx_status_t::SGX_SUCCESS); + mem_ptr + } as *mut u8; + if mem_ptr == std::ptr::null_mut() { + return Err(AllocErr); + } + + // Sanity checks + // Post-condition 1: alignment + debug_assert!(mem_ptr as usize % layout.align() == 0); + // Post-condition 2: out-of-enclave + assert!(sgx_trts::trts::rsgx_raw_is_outside_enclave( + mem_ptr as *const u8, + layout.size() + )); + Ok(NonNull::new(mem_ptr).unwrap()) + } + + unsafe fn dealloc(&mut self, ptr: NonNull, layout: Layout) { + // Pre-condition: out-of-enclave + debug_assert!(sgx_trts::trts::rsgx_raw_is_outside_enclave( + ptr.as_ptr(), + layout.size() + )); + + let sgx_status = unsafe { occlum_ocall_free(ptr.as_ptr() as *mut c_void) }; + debug_assert!(sgx_status == sgx_status_t::SGX_SUCCESS); + } +} + +extern "C" { + fn occlum_ocall_posix_memalign( + ptr: *mut *mut c_void, + align: usize, // must be power of two and a multiple of sizeof(void*) + size: usize, + ) -> sgx_status_t; + fn occlum_ocall_free(ptr: *mut c_void) -> sgx_status_t; +} diff --git a/src/libos/src/untrusted/mod.rs b/src/libos/src/untrusted/mod.rs new file mode 100644 index 00000000..9e35d794 --- /dev/null +++ b/src/libos/src/untrusted/mod.rs @@ -0,0 +1,10 @@ +/// Manipulate and access untrusted memory or functionalities safely +mod alloc; +mod slice_alloc; +mod slice_ext; + +use super::*; + +pub use self::alloc::UNTRUSTED_ALLOC; +pub use self::slice_alloc::UntrustedSliceAlloc; +pub use self::slice_ext::{SliceAsMutPtrAndLen, SliceAsPtrAndLen}; diff --git a/src/libos/src/untrusted/slice_alloc.rs b/src/libos/src/untrusted/slice_alloc.rs new file mode 100644 index 00000000..e4e33c98 --- /dev/null +++ b/src/libos/src/untrusted/slice_alloc.rs @@ -0,0 +1,81 @@ +use super::*; +use std::alloc::{Alloc, AllocErr, Layout}; +use std::ptr::NonNull; +use std::sync::atomic::{AtomicUsize, Ordering}; + +/// An memory allocator for slices, backed by a fixed-size, untrusted buffer +pub struct UntrustedSliceAlloc { + /// The pointer to the untrusted buffer + buf_ptr: *mut u8, + /// The size of the untrusted buffer + buf_size: usize, + /// The next position to allocate new slice + /// New slices must be allocated from [buf_ptr + buf_pos, buf_ptr + buf_size) + buf_pos: AtomicUsize, +} + +impl UntrustedSliceAlloc { + pub fn new(buf_size: usize) -> Result { + if buf_size == 0 { + // Create a dummy object + return Ok(Self { + buf_ptr: std::ptr::null_mut(), + buf_size: 0, + buf_pos: AtomicUsize::new(0), + }); + } + + let layout = Layout::from_size_align(buf_size, 1)?; + let buf_ptr = unsafe { UNTRUSTED_ALLOC.alloc(layout)?.as_ptr() }; + let buf_pos = AtomicUsize::new(0); + Ok(Self { + buf_ptr, + buf_size, + buf_pos, + }) + } + + pub fn new_slice(&self, src_slice: &[u8]) -> Result<&[u8]> { + let mut new_slice = self.new_slice_mut(src_slice.len())?; + new_slice.copy_from_slice(src_slice); + Ok(new_slice) + } + + pub fn new_slice_mut(&self, new_slice_len: usize) -> Result<&mut [u8]> { + let new_slice_ptr = { + // Move self.buf_pos forward if enough space _atomically_. + let old_pos = self + .buf_pos + .fetch_update( + |old_pos| { + let new_pos = old_pos + new_slice_len; + if new_pos <= self.buf_size { + Some(new_pos) + } else { + None + } + }, + Ordering::SeqCst, + Ordering::SeqCst, + ) + .map_err(|e| errno!(ENOMEM, "No enough space"))?; + unsafe { self.buf_ptr.add(old_pos) } + }; + let new_slice = unsafe { std::slice::from_raw_parts_mut(new_slice_ptr, new_slice_len) }; + Ok(new_slice) + } +} + +impl Drop for UntrustedSliceAlloc { + fn drop(&mut self) { + // Do nothing for the dummy case + if self.buf_size == 0 { + return; + } + + let layout = Layout::from_size_align(self.buf_size, 1).unwrap(); + unsafe { + UNTRUSTED_ALLOC.dealloc(NonNull::new(self.buf_ptr).unwrap(), layout); + } + } +} diff --git a/src/libos/src/untrusted/slice_ext.rs b/src/libos/src/untrusted/slice_ext.rs new file mode 100644 index 00000000..a4e82907 --- /dev/null +++ b/src/libos/src/untrusted/slice_ext.rs @@ -0,0 +1,76 @@ +/// Extension traits for slices +use super::*; +use std::ptr; + +/// An extension trait for slice to get its _const_ pointer and length. +/// +/// If the length is zero, then the pointer is null. This trait is handy when +/// it comes to converting slices to pointers and lengths for OCalls. +pub trait SliceAsPtrAndLen { + fn as_ptr_and_len(&self) -> (*const T, usize); +} + +impl SliceAsPtrAndLen for Option<&[T]> { + fn as_ptr_and_len(&self) -> (*const T, usize) { + match self { + Some(self_slice) => self_slice.as_ptr_and_len(), + None => (std::ptr::null(), 0), + } + } +} + +impl SliceAsPtrAndLen for Option<&mut [T]> { + fn as_ptr_and_len(&self) -> (*const T, usize) { + match self { + Some(self_slice) => self_slice.as_ptr_and_len(), + None => (std::ptr::null(), 0), + } + } +} + +impl SliceAsPtrAndLen for &[T] { + fn as_ptr_and_len(&self) -> (*const T, usize) { + if self.len() > 0 { + (self.as_ptr(), self.len()) + } else { + (ptr::null(), 0) + } + } +} + +impl SliceAsPtrAndLen for &mut [T] { + fn as_ptr_and_len(&self) -> (*const T, usize) { + if self.len() > 0 { + (self.as_ptr(), self.len()) + } else { + (ptr::null(), 0) + } + } +} + +/// An extension trait for slice to get its _mutable_ pointer and length. +/// +/// If the length is zero, then the pointer is null. This trait is handy when +/// it comes to converting slices to pointers and lengths for OCalls. +pub trait SliceAsMutPtrAndLen { + fn as_mut_ptr_and_len(&mut self) -> (*mut T, usize); +} + +impl SliceAsMutPtrAndLen for Option<&mut [T]> { + fn as_mut_ptr_and_len(&mut self) -> (*mut T, usize) { + match self { + Some(self_slice) => self_slice.as_mut_ptr_and_len(), + None => (std::ptr::null_mut(), 0), + } + } +} + +impl SliceAsMutPtrAndLen for &mut [T] { + fn as_mut_ptr_and_len(&mut self) -> (*mut T, usize) { + if self.len() > 0 { + (self.as_mut_ptr(), self.len()) + } else { + (ptr::null_mut(), 0) + } + } +} diff --git a/src/pal/include/edl/occlum_edl_types.h b/src/pal/include/edl/occlum_edl_types.h index 03e8a522..209eb7f8 100644 --- a/src/pal/include/edl/occlum_edl_types.h +++ b/src/pal/include/edl/occlum_edl_types.h @@ -3,5 +3,6 @@ #include // import struct timespec #include // import struct timeval +#include // import struct iovec #endif /* __OCCLUM_EDL_TYPES__ */ diff --git a/src/pal/src/ocalls/mem.c b/src/pal/src/ocalls/mem.c new file mode 100644 index 00000000..bfeebaa8 --- /dev/null +++ b/src/pal/src/ocalls/mem.c @@ -0,0 +1,27 @@ +#include +#include "ocalls.h" + +void* occlum_ocall_posix_memalign(size_t alignment, size_t size) { + void* ptr = NULL; + int ret = posix_memalign(&ptr, alignment, size); + if (ret == 0) { + return ptr; + } + + // Handle errors + switch(ret) { + case ENOMEM: + PAL_ERROR("Out of memory on the untrusted side"); + break; + case EINVAL: + PAL_ERROR("Invalid arguments given to occlum_ocall_posix_memalign"); + break; + default: + PAL_ERROR("Unexpected error in occlum_ocall_posix_memalign"); + } + return NULL; +} + +void occlum_ocall_free(void* ptr) { + free(ptr); +} diff --git a/src/pal/src/ocalls/net.c b/src/pal/src/ocalls/net.c index 92a6e3ea..704889df 100644 --- a/src/pal/src/ocalls/net.c +++ b/src/pal/src/ocalls/net.c @@ -6,20 +6,16 @@ ssize_t occlum_ocall_sendmsg(int sockfd, const void *msg_name, socklen_t msg_namelen, - const void *buf, - size_t buf_len, + const struct iovec *msg_iov, + size_t msg_iovlen, const void *msg_control, size_t msg_controllen, int flags) { - struct iovec msg_iov = { .iov_base = (void*)buf, .iov_len = buf_len }; - struct iovec* p_msg_iov = buf != NULL ? &msg_iov : NULL; - size_t msg_iovlen = buf != NULL ? 1 : 0; - struct msghdr msg = { (void*) msg_name, msg_namelen, - p_msg_iov, + (struct iovec *) msg_iov, msg_iovlen, (void*) msg_control, msg_controllen, @@ -32,22 +28,18 @@ ssize_t occlum_ocall_recvmsg(int sockfd, void *msg_name, socklen_t msg_namelen, socklen_t* msg_namelen_recv, - void *buf, - size_t buf_len, + struct iovec *msg_iov, + size_t msg_iovlen, void *msg_control, size_t msg_controllen, size_t* msg_controllen_recv, int* msg_flags_recv, int flags) { - struct iovec msg_iov = { .iov_base = buf, .iov_len = buf_len }; - struct iovec* p_msg_iov = buf != NULL ? &msg_iov : NULL; - size_t msg_iovlen = buf != NULL ? 1 : 0; - struct msghdr msg = { msg_name, msg_namelen, - p_msg_iov, + msg_iov, msg_iovlen, msg_control, msg_controllen, diff --git a/test/client/main.c b/test/client/main.c index 9d5b75d1..01dbf35c 100644 --- a/test/client/main.c +++ b/test/client/main.c @@ -76,6 +76,13 @@ int client_sendmsg(int server_fd, char *buf) { ret = sendmsg(server_fd, &msg, 0); if (ret <= 0) THROW_ERROR("sendmsg failed"); + + msg.msg_iov = NULL; + msg.msg_iovlen = 0; + + ret = sendmsg(server_fd, &msg, 0); + if (ret != 0) + THROW_ERROR("empty sendmsg failed"); return ret; } diff --git a/test/server/main.c b/test/server/main.c index d1bd9f18..4fd47800 100644 --- a/test/server/main.c +++ b/test/server/main.c @@ -64,13 +64,13 @@ int connect_with_child(int port, int *child_pid) { int neogotiate_msg(int client_fd) { char buf[16]; - if (write(client_fd, ECHO_MSG, sizeof(ECHO_MSG)) < 0) + if (write(client_fd, ECHO_MSG, strlen(ECHO_MSG)) < 0) THROW_ERROR("write failed"); if (read(client_fd, buf, 16) < 0) THROW_ERROR("read failed"); - if (strncmp(buf, RESPONSE, sizeof(RESPONSE)) != 0) { + if (strncmp(buf, RESPONSE, strlen(RESPONSE)) != 0) { THROW_ERROR("msg recv mismatch"); } return 0; @@ -83,7 +83,7 @@ int server_recv(int client_fd) { if (recv(client_fd, buf, buf_size, 0) <= 0) THROW_ERROR("msg recv failed"); - if (strncmp(buf, ECHO_MSG, sizeof(ECHO_MSG)) != 0) { + if (strncmp(buf, ECHO_MSG, strlen(ECHO_MSG)) != 0) { THROW_ERROR("msg recv mismatch"); } return 0; @@ -91,17 +91,21 @@ int server_recv(int client_fd) { int server_recvmsg(int client_fd) { int ret = 0; - const int buf_size = 1000; - char buf[buf_size]; + const int buf_size = 10; + char buf[3][buf_size]; struct msghdr msg; - struct iovec iov[1]; + struct iovec iov[3]; msg.msg_name = NULL; msg.msg_namelen = 0; - iov[0].iov_base = buf; + iov[0].iov_base = buf[0]; iov[0].iov_len = buf_size; + iov[1].iov_base = buf[1]; + iov[1].iov_len = buf_size; + iov[2].iov_base = buf[2]; + iov[2].iov_len = buf_size; msg.msg_iov = iov; - msg.msg_iovlen = 1; + msg.msg_iovlen = 3; msg.msg_control = 0; msg.msg_controllen = 0; msg.msg_flags = 0; @@ -110,11 +114,18 @@ int server_recvmsg(int client_fd) { if (ret <= 0) { THROW_ERROR("recvmsg failed"); } else { - if (strncmp(buf, ECHO_MSG, sizeof(ECHO_MSG)) != 0) { - printf("recvmsg : %d, msg: %s\n", ret, buf); + if (strncmp(buf[0], ECHO_MSG, buf_size) != 0 && + strstr(ECHO_MSG, buf[1]) != NULL && + strstr(ECHO_MSG, buf[2]) != NULL) { + printf("recvmsg : %d, msg: %s, %s, %s\n", ret, buf[0], buf[1], buf[2]); THROW_ERROR("msg recvmsg mismatch"); } } + msg.msg_iov = NULL; + msg.msg_iovlen = 0; + ret = recvmsg(client_fd, &msg, 0); + if (ret != 0) + THROW_ERROR("recvmsg empty failed"); return ret; } @@ -157,7 +168,7 @@ int server_connectionless_recvmsg() { if (ret <= 0) { THROW_ERROR("recvmsg failed"); } else { - if (strncmp(buf, DEFAULT_MSG, sizeof(DEFAULT_MSG)) != 0) { + if (strncmp(buf, DEFAULT_MSG, strlen(DEFAULT_MSG)) != 0) { printf("recvmsg : %d, msg: %s\n", ret, buf); THROW_ERROR("msg recvmsg mismatch"); } else {