diff --git a/src/libos/src/net/socket/host/mod.rs b/src/libos/src/net/socket/host/mod.rs index fe85e2d8..2b7606f2 100644 --- a/src/libos/src/net/socket/host/mod.rs +++ b/src/libos/src/net/socket/host/mod.rs @@ -4,6 +4,9 @@ use std::mem; use atomic::Atomic; +use self::recv::Receiver; +use self::send::Sender; + use super::*; use crate::fs::{ occlum_ocall_ioctl, AccessMode, CreationFlags, File, FileRef, HostFd, IoEvents, IoNotifier, @@ -15,12 +18,16 @@ mod recv; mod send; mod socket_file; +pub const SEND_BUF_SIZE: usize = 128 * 1024; +pub const RECV_BUF_SIZE: usize = 128 * 1024; /// Native linux socket #[derive(Debug)] pub struct HostSocket { host_fd: HostFd, host_events: Atomic, notifier: IoNotifier, + sender: SgxMutex, + receiver: SgxMutex, } impl HostSocket { @@ -36,17 +43,21 @@ impl HostSocket { protocol )) as FileDesc; let host_fd = HostFd::new(raw_host_fd); - Ok(HostSocket::from_host_fd(host_fd)) + Ok(HostSocket::from_host_fd(host_fd)?) } - fn from_host_fd(host_fd: HostFd) -> HostSocket { + fn from_host_fd(host_fd: HostFd) -> Result { let host_events = Atomic::new(IoEvents::empty()); let notifier = IoNotifier::new(); - Self { + let sender = SgxMutex::new(Sender::new()?); + let receiver = SgxMutex::new(Receiver::new()?); + Ok(Self { host_fd, host_events, notifier, - } + sender, + receiver, + }) } pub fn bind(&self, addr: &SockAddr) -> Result<()> { @@ -83,7 +94,7 @@ impl HostSocket { } else { None }; - Ok((HostSocket::from_host_fd(host_fd), addr_option)) + Ok((HostSocket::from_host_fd(host_fd)?, addr_option)) } pub fn connect(&self, addr: &Option) -> Result<()> { diff --git a/src/libos/src/net/socket/host/recv.rs b/src/libos/src/net/socket/host/recv.rs index 63824f65..3d8ebab9 100644 --- a/src/libos/src/net/socket/host/recv.rs +++ b/src/libos/src/net/socket/host/recv.rs @@ -1,6 +1,18 @@ use super::*; use crate::untrusted::{SliceAsMutPtrAndLen, SliceAsPtrAndLen, UntrustedSliceAlloc}; +#[derive(Debug)] +pub struct Receiver { + alloc: UntrustedSliceAlloc, +} + +impl Receiver { + pub fn new() -> Result { + let alloc = UntrustedSliceAlloc::new(RECV_BUF_SIZE)?; + Ok(Self { alloc }) + } +} + impl HostSocket { pub fn recv(&self, buf: &mut [u8], flags: RecvFlags) -> Result { let (bytes_recvd, _) = self.recvfrom(buf, flags)?; @@ -32,11 +44,24 @@ impl HostSocket { mut control: Option<&mut [u8]>, ) -> Result<(usize, usize, usize, MsgHdrFlags)> { let data_length = data.iter().map(|s| s.len()).sum(); - let u_allocator = UntrustedSliceAlloc::new(data_length)?; + let mut receiver: SgxMutexGuard<'_, Receiver>; + let mut ocall_alloc; + // Allocated slice in untrusted memory region + let u_allocator = if data_length > RECV_BUF_SIZE { + // Ocall allocator + ocall_alloc = UntrustedSliceAlloc::new(data_length)?; + &mut ocall_alloc + } else { + // Inner allocator, lock buffer until recv ocall completion + receiver = self.receiver.lock().unwrap(); + &mut receiver.alloc + }; + let mut u_data = { let mut bufs = Vec::new(); for ref buf in data.iter() { - bufs.push(u_allocator.new_slice_mut(buf.len())?); + let u_slice = u_allocator.new_slice_mut(buf.len())?; + bufs.push(u_slice); } bufs }; @@ -52,6 +77,7 @@ impl HostSocket { break; } } + u_allocator.reset(); Ok(retval) } diff --git a/src/libos/src/net/socket/host/send.rs b/src/libos/src/net/socket/host/send.rs index 15487a75..0a65a53b 100644 --- a/src/libos/src/net/socket/host/send.rs +++ b/src/libos/src/net/socket/host/send.rs @@ -1,5 +1,17 @@ use super::*; +#[derive(Debug)] +pub struct Sender { + alloc: UntrustedSliceAlloc, +} + +impl Sender { + pub fn new() -> Result { + let alloc = UntrustedSliceAlloc::new(SEND_BUF_SIZE)?; + Ok(Self { alloc }) + } +} + impl HostSocket { pub fn send(&self, buf: &[u8], flags: SendFlags) -> Result { self.sendto(buf, flags, &None) @@ -24,16 +36,31 @@ impl HostSocket { control: Option<&[u8]>, ) -> Result { let data_length = data.iter().map(|s| s.len()).sum(); - let u_allocator = UntrustedSliceAlloc::new(data_length)?; + let mut sender: SgxMutexGuard<'_, Sender>; + let mut ocall_alloc; + // Allocated slice in untrusted memory region + let u_allocator = if data_length > SEND_BUF_SIZE { + // Ocall allocator + ocall_alloc = UntrustedSliceAlloc::new(data_length)?; + &mut ocall_alloc + } else { + // Inner allocator, lock buffer until send ocall completion + sender = self.sender.lock().unwrap(); + &mut sender.alloc + }; + let u_data = { let mut bufs = Vec::new(); for buf in data { - bufs.push(u_allocator.new_slice(buf)?); + let u_slice = u_allocator.new_slice(buf)?; + bufs.push(u_slice); } bufs }; - self.do_sendmsg_untrusted_data(&u_data, flags, name, control) + let retval = self.do_sendmsg_untrusted_data(&u_data, flags, name, control); + u_allocator.reset(); + retval } fn do_sendmsg_untrusted_data( diff --git a/src/libos/src/untrusted/slice_alloc.rs b/src/libos/src/untrusted/slice_alloc.rs index 84d1ae3a..b69ed74d 100644 --- a/src/libos/src/untrusted/slice_alloc.rs +++ b/src/libos/src/untrusted/slice_alloc.rs @@ -1,5 +1,6 @@ use super::*; use std::alloc::{AllocError, Allocator, Layout}; +use std::fmt; use std::ops::{Deref, DerefMut}; use std::ptr::NonNull; use std::sync::atomic::{AtomicUsize, Ordering}; @@ -21,6 +22,8 @@ pub struct UntrustedSliceAlloc { buf_pos: AtomicUsize, } +unsafe impl Send for UntrustedSliceAlloc {} + impl UntrustedSliceAlloc { pub fn new(buf_size: usize) -> Result { if buf_size == 0 { @@ -68,6 +71,10 @@ impl UntrustedSliceAlloc { let new_slice = unsafe { std::slice::from_raw_parts_mut(new_slice_ptr, new_slice_len) }; Ok(UntrustedSlice { slice: new_slice }) } + + pub fn reset(&mut self) { + self.buf_pos.store(0, Ordering::Relaxed); + } } impl Drop for UntrustedSliceAlloc { @@ -84,6 +91,15 @@ impl Drop for UntrustedSliceAlloc { } } +impl Debug for UntrustedSliceAlloc { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("UntrustedSliceAlloc") + .field("buf size", &self.buf_size) + .field("buf pos", &self.buf_pos.load(Ordering::Relaxed)) + .finish() + } +} + pub struct UntrustedSlice<'a> { slice: &'a mut [u8], }