Add sendmsg/recvmsg syscalls

1. Add a separate net/ directory for the network subsystem;
2. Move some existing socket code to net/;
3. Implement sendmsg/recvmsg with OCalls;
4. Extend client/server test cases.
This commit is contained in:
He Sun 2019-11-13 12:59:18 +00:00 committed by Tate, Hongliang Tian
parent 2052447950
commit 0cef5b1b53
18 changed files with 1213 additions and 125 deletions

@ -40,5 +40,30 @@ enclave {
[out] sgx_report_t* qe_report,
[out, size=quote_buf_len] uint8_t* quote_buf,
uint32_t quote_buf_len);
int64_t ocall_sendmsg(
int sockfd,
[in, size=msg_namelen] const void* msg_name,
socklen_t msg_namelen,
[in, size=buf_len] const void* buf,
size_t buf_len,
[in, size=msg_controllen] const void* msg_control,
size_t msg_controllen,
int flags
) propagate_errno;
int64_t ocall_recvmsg(
int sockfd,
[out, size=msg_namelen] void *msg_name,
socklen_t msg_namelen,
[out] socklen_t* msg_namelen_recv,
[out, size=buf_len] void* buf,
size_t buf_len,
[out, size=msg_controllen] void *msg_control,
size_t msg_controllen,
[out] size_t* msg_controllen_recv,
[out] int* msg_flags_recv,
int flags
) propagate_errno;
};
};

@ -141,6 +141,7 @@ pub enum Errno {
EHWPOISON = 133,
// Note: always keep the last item in sync with ERRNO_MAX
}
const ERRNO_MIN: u32 = Errno::EPERM as u32;
const ERRNO_MAX: u32 = Errno::EHWPOISON as u32;
impl Errno {
@ -192,10 +193,8 @@ impl Errno {
}
impl From<u32> for Errno {
fn from(mut raw_errno: u32) -> Self {
if raw_errno > ERRNO_MAX {
raw_errno = 0;
}
fn from(raw_errno: u32) -> Self {
assert!(ERRNO_MIN <= raw_errno && raw_errno <= ERRNO_MAX);
unsafe { core::mem::transmute(raw_errno as u8) }
}
}

@ -44,7 +44,7 @@ macro_rules! return_errno {
macro_rules! try_libc {
($ret: expr) => {{
let ret = unsafe { $ret };
if ret == -1 {
if ret < 0 {
let errno = unsafe { libc::errno() };
return_errno!(Errno::from(errno as u32), "libc error");
}

@ -1,5 +1,6 @@
use super::*;
use net::{AsSocket, SocketFile};
use process::Process;
use rcore_fs::vfs::{FileType, FsError, INode, Metadata, Timespec};
use std::io::{Read, Seek, SeekFrom, Write};
@ -19,7 +20,6 @@ pub use self::io_multiplexing::*;
pub use self::ioctl::*;
pub use self::pipe::Pipe;
pub use self::root_inode::ROOT_INODE;
pub use self::socket_file::{AsSocket, SocketFile};
pub use self::unix_socket::{AsUnixSocket, UnixSocketFile};
use sgx_trts::libc::S_IWUSR;
use std::any::Any;
@ -39,7 +39,6 @@ mod ioctl;
mod pipe;
mod root_inode;
mod sgx_impl;
mod socket_file;
mod unix_socket;
pub fn do_open(path: &str, flags: u32, mode: u32) -> Result<FileDesc> {

@ -57,6 +57,7 @@ mod entry;
mod exception;
mod fs;
mod misc;
mod net;
mod process;
mod syscall;
mod time;

88
src/libos/src/net/iovs.rs Normal file

@ -0,0 +1,88 @@
//! I/O vectors
use super::*;
/// A memory safe, immutable version of C iovec array
pub struct Iovs<'a> {
iovs: Vec<&'a [u8]>,
}
impl<'a> Iovs<'a> {
pub fn new(slices: Vec<&'a [u8]>) -> Iovs {
Self { iovs: slices }
}
pub fn as_slices(&self) -> &[&[u8]] {
&self.iovs[..]
}
pub fn total_bytes(&self) -> usize {
self.iovs.iter().map(|s| s.len()).sum()
}
pub fn gather_to_vec(&self) -> Vec<u8> {
Self::gather_slices_to_vec(&self.iovs[..])
}
fn gather_slices_to_vec(slices: &[&[u8]]) -> Vec<u8> {
let vec_len = slices.iter().map(|slice| slice.len()).sum();
let mut vec = Vec::with_capacity(vec_len);
for slice in slices {
vec.extend_from_slice(slice);
}
vec
}
}
/// A memory safe, mutable version of C iovec array
pub struct IovsMut<'a> {
iovs: Vec<&'a mut [u8]>,
}
impl<'a> IovsMut<'a> {
pub fn new(slices: Vec<&'a mut [u8]>) -> Self {
Self { iovs: slices }
}
pub fn as_slices<'b>(&'b self) -> &'b [&'a [u8]] {
let slices_mut: &'b [&'a mut [u8]] = &self.iovs[..];
// We are "downgrading" _mutable_ slices to _immutable_ ones. It should be
// safe to do this transmute
unsafe { std::mem::transmute(slices_mut) }
}
pub fn as_slices_mut<'b>(&'b mut self) -> &'b mut [&'a mut [u8]] {
&mut self.iovs[..]
}
pub fn total_bytes(&self) -> usize {
self.iovs.iter().map(|s| s.len()).sum()
}
pub fn gather_to_vec(&self) -> Vec<u8> {
Iovs::gather_slices_to_vec(self.as_slices())
}
pub fn scatter_copy_from(&mut self, data: &[u8]) -> usize {
let mut total_nbytes = 0;
let mut remain_slice = data;
for iov in &mut self.iovs {
if remain_slice.len() == 0 {
break;
}
let copy_nbytes = remain_slice.len().min(iov.len());
let dst_slice = unsafe {
debug_assert!(iov.len() >= copy_nbytes);
iov.get_unchecked_mut(..copy_nbytes)
};
let (src_slice, _remain_slice) = remain_slice.split_at(copy_nbytes);
dst_slice.copy_from_slice(src_slice);
remain_slice = _remain_slice;
total_nbytes += copy_nbytes;
}
debug_assert!(remain_slice.len() == 0);
total_nbytes
}
}

13
src/libos/src/net/mod.rs Normal file

@ -0,0 +1,13 @@
use super::*;
mod iovs;
mod msg;
mod msg_flags;
mod socket_file;
mod syscalls;
pub use self::iovs::{Iovs, IovsMut};
pub use self::msg::{msghdr, msghdr_mut, MsgHdr, MsgHdrMut};
pub use self::msg_flags::MsgFlags;
pub use self::socket_file::{AsSocket, SocketFile};
pub use self::syscalls::{do_recvmsg, do_sendmsg};

231
src/libos/src/net/msg.rs Normal file

@ -0,0 +1,231 @@
/// Socket message and its flags.
use super::*;
/// C struct for a socket message with const pointers
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct msghdr {
pub msg_name: *const c_void,
pub msg_namelen: libc::socklen_t,
pub msg_iov: *const libc::iovec,
pub msg_iovlen: size_t,
pub msg_control: *const c_void,
pub msg_controllen: size_t,
pub msg_flags: c_int,
}
/// C struct for a socket message with mutable pointers
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct msghdr_mut {
pub msg_name: *mut c_void,
pub msg_namelen: libc::socklen_t,
pub msg_iov: *mut libc::iovec,
pub msg_iovlen: size_t,
pub msg_control: *mut c_void,
pub msg_controllen: size_t,
pub msg_flags: c_int,
}
/// MsgHdr is a memory-safe, immutable wrapper of msghdr
pub struct MsgHdr<'a> {
name: Option<&'a [u8]>,
iovs: Iovs<'a>,
control: Option<&'a [u8]>,
flags: MsgFlags,
c_self: &'a msghdr,
}
impl<'a> MsgHdr<'a> {
/// Wrap a unsafe msghdr into a safe MsgHdr
pub unsafe fn from_c(c_msg: &'a msghdr) -> Result<MsgHdr> {
// Convert c_msg's (*mut T, usize)-pair fields to Option<&mut [T]>
let name_opt_slice =
new_optional_slice(c_msg.msg_name as *const u8, c_msg.msg_namelen as usize);
let iovs_opt_slice = new_optional_slice(
c_msg.msg_iov as *const libc::iovec,
c_msg.msg_iovlen as usize,
);
let control_opt_slice = new_optional_slice(
c_msg.msg_control as *const u8,
c_msg.msg_controllen as usize,
);
let flags = MsgFlags::from_u32(c_msg.msg_flags as u32)?;
let iovs = {
let iovs_vec = match iovs_opt_slice {
Some(iovs_slice) => iovs_slice
.iter()
.flat_map(|iov| new_optional_slice(iov.iov_base as *const u8, iov.iov_len))
.collect(),
None => Vec::new(),
};
Iovs::new(iovs_vec)
};
Ok(Self {
name: name_opt_slice,
iovs: iovs,
control: control_opt_slice,
flags: flags,
c_self: c_msg,
})
}
pub fn get_iovs(&self) -> &Iovs {
&self.iovs
}
pub fn get_name(&self) -> Option<&[u8]> {
self.name
}
pub fn get_control(&self) -> Option<&[u8]> {
self.control
}
pub fn get_flags(&self) -> MsgFlags {
self.flags
}
}
/// MsgHdrMut is a memory-safe, mutable wrapper of msghdr_mut
pub struct MsgHdrMut<'a> {
name: Option<&'a mut [u8]>,
iovs: IovsMut<'a>,
control: Option<&'a mut [u8]>,
flags: MsgFlags,
c_self: &'a mut msghdr_mut,
}
// TODO: use macros to eliminate redundant code between MsgHdr and MsgHdrMut
impl<'a> MsgHdrMut<'a> {
/// Wrap a unsafe msghdr_mut into a safe MsgHdrMut
pub unsafe fn from_c(c_msg: &'a mut msghdr_mut) -> Result<MsgHdrMut> {
// Convert c_msg's (*mut T, usize)-pair fields to Option<&mut [T]>
let name_opt_slice =
new_optional_slice_mut(c_msg.msg_name as *mut u8, c_msg.msg_namelen as usize);
let iovs_opt_slice =
new_optional_slice_mut(c_msg.msg_iov as *mut libc::iovec, c_msg.msg_iovlen as usize);
let control_opt_slice =
new_optional_slice_mut(c_msg.msg_control as *mut u8, c_msg.msg_controllen as usize);
let flags = MsgFlags::from_u32(c_msg.msg_flags as u32)?;
let iovs = {
let iovs_vec = match iovs_opt_slice {
Some(iovs_slice) => iovs_slice
.iter()
.flat_map(|iov| new_optional_slice_mut(iov.iov_base as *mut u8, iov.iov_len))
.collect(),
None => Vec::new(),
};
IovsMut::new(iovs_vec)
};
Ok(Self {
name: name_opt_slice,
iovs: iovs,
control: control_opt_slice,
flags: flags,
c_self: c_msg,
})
}
/////////////////////////////////////////////////////////////////////////
// Immutable interfaces (same as MsgHdr)
/////////////////////////////////////////////////////////////////////////
pub fn get_iovs(&self) -> &IovsMut {
&self.iovs
}
pub fn get_name(&self) -> Option<&[u8]> {
self.name.as_ref().map(|name| &name[..])
}
pub fn get_control(&self) -> Option<&[u8]> {
self.control.as_ref().map(|control| &control[..])
}
pub fn get_flags(&self) -> MsgFlags {
self.flags
}
/////////////////////////////////////////////////////////////////////////
// Mutable interfaces (unique to MsgHdrMut)
/////////////////////////////////////////////////////////////////////////
pub fn get_iovs_mut<'b>(&'b mut self) -> &'b mut IovsMut<'a> {
&mut self.iovs
}
pub fn get_name_mut(&mut self) -> Option<&mut [u8]> {
self.name.as_mut().map(|name| &mut name[..])
}
pub fn get_name_max_len(&self) -> usize {
self.name.as_ref().map(|name| name.len()).unwrap_or(0)
}
pub fn set_name_len(&mut self, new_name_len: usize) -> Result<()> {
if new_name_len > self.get_name_max_len() {
return_errno!(EINVAL, "new_name_len is too big");
}
self.c_self.msg_namelen = new_name_len as libc::socklen_t;
Ok(())
}
pub fn get_control_mut(&mut self) -> Option<&mut [u8]> {
self.control.as_mut().map(|control| &mut control[..])
}
pub fn get_control_max_len(&self) -> usize {
self.control
.as_ref()
.map(|control| control.len())
.unwrap_or(0)
}
pub fn set_control_len(&mut self, new_control_len: usize) -> Result<()> {
if new_control_len > self.get_control_max_len() {
return_errno!(EINVAL, "new_control_len is too big");
}
self.c_self.msg_controllen = new_control_len;
Ok(())
}
pub fn get_name_and_control_mut(&mut self) -> (Option<&mut [u8]>, Option<&mut [u8]>) {
(
self.name.as_mut().map(|name| &mut name[..]),
self.control.as_mut().map(|control| &mut control[..]),
)
}
pub fn set_flags(&mut self, flags: MsgFlags) {
self.flags = flags;
self.c_self.msg_flags = flags.to_u32() as i32;
}
}
unsafe fn new_optional_slice<'a, T>(slice_ptr: *const T, slice_size: usize) -> Option<&'a [T]> {
if !slice_ptr.is_null() {
let slice = core::slice::from_raw_parts::<T>(slice_ptr, slice_size);
Some(slice)
} else {
None
}
}
unsafe fn new_optional_slice_mut<'a, T>(
slice_ptr: *mut T,
slice_size: usize,
) -> Option<&'a mut [T]> {
if !slice_ptr.is_null() {
let slice = core::slice::from_raw_parts_mut::<T>(slice_ptr, slice_size);
Some(slice)
} else {
None
}
}

@ -0,0 +1,17 @@
use super::*;
// TODO: use bitflag! to make this memory safe
#[derive(Debug, Copy, Clone, Default)]
pub struct MsgFlags {
bits: u32,
}
impl MsgFlags {
pub fn from_u32(c_flags: u32) -> Result<MsgFlags> {
Ok(MsgFlags { bits: 0 })
}
pub fn to_u32(&self) -> u32 {
self.bits
}
}

@ -1,16 +1,22 @@
use super::*;
mod recv;
mod send;
use fs::{File, FileRef, IoctlCmd};
use std::any::Any;
use std::io::{Read, Seek, SeekFrom, Write};
/// Native Linux socket
#[derive(Debug)]
pub struct SocketFile {
fd: c_int,
host_fd: c_int,
}
impl SocketFile {
pub fn new(domain: c_int, socket_type: c_int, protocol: c_int) -> Result<Self> {
let ret = try_libc!(libc::ocall::socket(domain, socket_type, protocol));
Ok(SocketFile { fd: ret })
Ok(SocketFile { host_fd: ret })
}
pub fn accept(
@ -19,32 +25,28 @@ impl SocketFile {
addr_len: *mut libc::socklen_t,
flags: c_int,
) -> Result<Self> {
let ret = try_libc!(libc::ocall::accept4(self.fd, addr, addr_len, flags));
Ok(SocketFile { fd: ret })
let ret = try_libc!(libc::ocall::accept4(self.host_fd, addr, addr_len, flags));
Ok(SocketFile { host_fd: ret })
}
pub fn fd(&self) -> c_int {
self.fd
self.host_fd
}
}
impl Drop for SocketFile {
fn drop(&mut self) {
let ret = unsafe { libc::ocall::close(self.fd) };
if ret < 0 {
let errno = unsafe { libc::errno() };
warn!(
"socket (host fd: {}) close failed: errno = {}",
self.fd, errno
);
}
let ret = unsafe { libc::ocall::close(self.host_fd) };
assert!(ret == 0);
}
}
// TODO: rewrite read/write/readv/writev as send/recv
// TODO: implement readfrom/sendto
impl File for SocketFile {
fn read(&self, buf: &mut [u8]) -> Result<usize> {
let ret = try_libc!(libc::ocall::read(
self.fd,
self.host_fd,
buf.as_mut_ptr() as *mut c_void,
buf.len()
));
@ -53,7 +55,7 @@ impl File for SocketFile {
fn write(&self, buf: &[u8]) -> Result<usize> {
let ret = try_libc!(libc::ocall::write(
self.fd,
self.host_fd,
buf.as_ptr() as *const c_void,
buf.len()
));
@ -100,25 +102,6 @@ impl File for SocketFile {
return_errno!(ESPIPE, "Socket does not support seek")
}
fn metadata(&self) -> Result<Metadata> {
Ok(Metadata {
dev: 0,
inode: 0,
size: 0,
blk_size: 0,
blocks: 0,
atime: Timespec { sec: 0, nsec: 0 },
mtime: Timespec { sec: 0, nsec: 0 },
ctime: Timespec { sec: 0, nsec: 0 },
type_: FileType::Socket,
mode: 0,
nlinks: 0,
uid: 0,
gid: 0,
rdev: 0,
})
}
fn ioctl(&self, cmd: &mut IoctlCmd) -> Result<()> {
let cmd_num = cmd.cmd_num() as c_int;
let cmd_arg_ptr = cmd.arg_ptr() as *const c_int;

@ -0,0 +1,138 @@
use super::*;
impl SocketFile {
// TODO: need sockaddr type to implement send/sento
/*
pub fn recv(&self, buf: &mut [u8], flags: MsgFlags) -> Result<usize> {
let (bytes_recvd, _) = self.recvfrom(buf, flags, None)?;
Ok(bytes_recvd)
}
pub fn recvfrom(&self, buf: &mut [u8], flags: MsgFlags, src_addr: Option<&mut [u8]>) -> Result<(usize, usize)> {
let (bytes_recvd, src_addr_len, _, _) = self.do_recvmsg(
&mut buf[..],
flags,
src_addr,
None,
)?;
Ok((bytes_recvd, src_addr_len))
}*/
pub fn recvmsg<'a, 'b>(&self, msg: &'b mut MsgHdrMut<'a>, flags: MsgFlags) -> Result<usize> {
// Allocate a single data buffer is big enough for all iovecs of msg.
// This is a workaround for the OCall that takes only one data buffer.
let mut data_buf = {
let data_buf_len = msg.get_iovs().total_bytes();
let data_vec = vec![0; data_buf_len];
data_vec.into_boxed_slice()
};
let (bytes_recvd, namelen_recvd, controllen_recvd, flags_recvd) = {
let data = &mut data_buf[..];
// Acquire mutable references to the name and control buffers
let (name, control) = msg.get_name_and_control_mut();
// Fill the data, the name, and the control buffers
self.do_recvmsg(data, flags, name, control)?
};
// Update the lengths and flags
msg.set_name_len(namelen_recvd)?;
msg.set_control_len(controllen_recvd)?;
msg.set_flags(flags_recvd);
let recv_data = &data_buf[..bytes_recvd];
// TODO: avoid this one extra copy due to the intermediate data buffer
msg.get_iovs_mut().scatter_copy_from(recv_data);
Ok(bytes_recvd)
}
fn do_recvmsg(
&self,
data: &mut [u8],
flags: MsgFlags,
mut name: Option<&mut [u8]>,
mut control: Option<&mut [u8]>,
) -> Result<(usize, usize, usize, MsgFlags)> {
// Prepare the arguments for OCall
// Host socket fd
let host_fd = self.host_fd;
// Name
let (msg_name, msg_namelen) = name.get_mut_ptr_and_len();
let msg_name = msg_name as *mut c_void;
let mut msg_namelen_recvd = 0_u32;
// Data
let msg_data = data.as_mut_ptr();
let msg_datalen = data.len();
// Control
let (msg_control, msg_controllen) = control.get_mut_ptr_and_len();
let msg_control = msg_control as *mut c_void;
let mut msg_controllen_recvd = 0;
// Flags
let flags = flags.to_u32() as i32;
let mut msg_flags_recvd = 0;
// Do OCall
let retval = try_libc!({
let mut retval = 0_isize;
let status = ocall_recvmsg(
&mut retval as *mut isize,
host_fd,
msg_name,
msg_namelen as u32,
&mut msg_namelen_recvd as *mut u32,
msg_data,
msg_datalen,
msg_control,
msg_controllen,
&mut msg_controllen_recvd as *mut usize,
&mut msg_flags_recvd as *mut i32,
flags,
);
assert!(status == sgx_status_t::SGX_SUCCESS);
// TODO: what if retval < 0 but buffers are modified by the
// untrusted OCall? We reset the potentially tampered buffers.
retval
});
// Check values returned from outside the enclave
let bytes_recvd = {
// Guarantted by try_libc!
debug_assert!(retval >= 0);
let retval = retval as usize;
// Check bytes_recvd returned from outside the enclave
assert!(retval <= data.len());
retval
};
let msg_namelen_recvd = msg_namelen_recvd as usize;
assert!(msg_namelen_recvd <= msg_namelen);
assert!(msg_controllen_recvd <= msg_controllen);
let flags_recvd = MsgFlags::from_u32(msg_flags_recvd as u32)?;
Ok((
bytes_recvd,
msg_namelen_recvd,
msg_controllen_recvd,
flags_recvd,
))
}
}
extern "C" {
fn ocall_recvmsg(
ret: *mut ssize_t,
fd: c_int,
msg_name: *mut c_void,
msg_namelen: libc::socklen_t,
msg_namelen_recv: *mut libc::socklen_t,
msg_data: *mut u8,
msg_data: size_t,
msg_control: *mut c_void,
msg_controllen: size_t,
msg_controllen_recv: *mut size_t,
msg_flags: *mut c_int,
flags: c_int,
) -> sgx_status_t;
}

@ -0,0 +1,84 @@
use super::*;
impl SocketFile {
// TODO: need sockaddr type to implement send/sento
/*
pub fn send(&self, buf: &[u8], flags: MsgFlags) -> Result<usize> {
self.sendto(buf, flags, None)
}
pub fn sendto(&self, buf: &[u8], flags: MsgFlags, dest_addr: Option<&[u8]>) -> Result<usize> {
Self::do_sendmsg(
self.host_fd,
&buf[..],
flags,
dest_addr,
None)
}
*/
pub fn sendmsg<'a, 'b>(&self, msg: &'b MsgHdr<'a>, flags: MsgFlags) -> Result<usize> {
// Copy data in iovs into a single buffer
let data_buf = msg.get_iovs().gather_to_vec();
self.do_sendmsg(&data_buf[..], flags, msg.get_name(), msg.get_control())
}
fn do_sendmsg(
&self,
data: &[u8],
flags: MsgFlags,
name: Option<&[u8]>,
control: Option<&[u8]>,
) -> Result<usize> {
let bytes_sent = try_libc!({
// Prepare the arguments for OCall
let mut retval: isize = 0;
// Host socket fd
let host_fd = self.host_fd;
// Name
let (msg_name, msg_namelen) = name.get_ptr_and_len();
let msg_name = msg_name as *const c_void;
// Data
let msg_data = data.as_ptr();
let msg_datalen = data.len();
// Control
let (msg_control, msg_controllen) = control.get_ptr_and_len();
let msg_control = msg_control as *const c_void;
// Flags
let flags = flags.to_u32() as i32;
// Do OCall
let status = ocall_sendmsg(
&mut retval as *mut isize,
host_fd,
msg_name,
msg_namelen as u32,
msg_data,
msg_datalen,
msg_control,
msg_controllen,
flags,
);
assert!(status == sgx_status_t::SGX_SUCCESS);
retval
});
debug_assert!(bytes_sent >= 0);
Ok(bytes_sent as usize)
}
}
extern "C" {
fn ocall_sendmsg(
ret: *mut ssize_t,
fd: c_int,
msg_name: *const c_void,
msg_namelen: libc::socklen_t,
msg_data: *const u8,
msg_datalen: size_t,
msg_control: *const c_void,
msg_controllen: size_t,
flags: c_int,
) -> sgx_status_t;
}

@ -0,0 +1,126 @@
use super::*;
use fs::{AsUnixSocket, File, FileDesc, FileRef, UnixSocketFile};
use process::Process;
use util::mem_util::from_user;
pub fn do_sendmsg(fd: c_int, msg_ptr: *const msghdr, flags_c: c_int) -> Result<isize> {
info!(
"sendmsg: fd: {}, msg: {:?}, flags: 0x{:x}",
fd, msg_ptr, flags_c
);
let current_ref = process::get_current();
let mut proc = current_ref.lock().unwrap();
let file_ref = proc.get_files().lock().unwrap().get(fd as FileDesc)?;
if let Ok(socket) = file_ref.as_socket() {
let msg_c = {
from_user::check_ptr(msg_ptr)?;
let msg_c = unsafe { &*msg_ptr };
msg_c.check_member_ptrs()?;
msg_c
};
let msg = unsafe { MsgHdr::from_c(&msg_c)? };
let flags = MsgFlags::from_u32(flags_c as u32)?;
socket
.sendmsg(&msg, flags)
.map(|bytes_sent| bytes_sent as isize)
} else if let Ok(socket) = file_ref.as_unix_socket() {
return_errno!(EBADF, "does not support unix socket")
} else {
return_errno!(EBADF, "not a socket")
}
}
pub fn do_recvmsg(fd: c_int, msg_mut_ptr: *mut msghdr_mut, flags_c: c_int) -> Result<isize> {
info!(
"recvmsg: fd: {}, msg: {:?}, flags: 0x{:x}",
fd, msg_mut_ptr, flags_c
);
let current_ref = process::get_current();
let mut proc = current_ref.lock().unwrap();
let file_ref = proc.get_files().lock().unwrap().get(fd as FileDesc)?;
if let Ok(socket) = file_ref.as_socket() {
let msg_mut_c = {
from_user::check_mut_ptr(msg_mut_ptr)?;
let msg_mut_c = unsafe { &mut *msg_mut_ptr };
msg_mut_c.check_member_ptrs()?;
msg_mut_c
};
let mut msg_mut = unsafe { MsgHdrMut::from_c(msg_mut_c)? };
let flags = MsgFlags::from_u32(flags_c as u32)?;
socket
.recvmsg(&mut msg_mut, flags)
.map(|bytes_recvd| bytes_recvd as isize)
} else if let Ok(socket) = file_ref.as_unix_socket() {
return_errno!(EBADF, "does not support unix socket")
} else {
return_errno!(EBADF, "not a socket")
}
}
#[allow(non_camel_case_types)]
trait c_msghdr_ext {
fn check_member_ptrs(&self) -> Result<()>;
}
impl c_msghdr_ext for msghdr {
// TODO: implement this!
fn check_member_ptrs(&self) -> Result<()> {
Ok(())
}
/*
///user space check
pub unsafe fn check_from_user(user_hdr: *const msghdr) -> Result<()> {
Self::check_pointer(user_hdr, from_user::check_ptr)
}
///Check msghdr ptr
pub unsafe fn check_pointer(
user_hdr: *const msghdr,
check_ptr: fn(*const u8) -> Result<()>,
) -> Result<()> {
check_ptr(user_hdr as *const u8)?;
if (*user_hdr).msg_name.is_null() ^ ((*user_hdr).msg_namelen == 0) {
return_errno!(EINVAL, "name length is invalid");
}
if (*user_hdr).msg_iov.is_null() ^ ((*user_hdr).msg_iovlen == 0) {
return_errno!(EINVAL, "iov length is invalid");
}
if (*user_hdr).msg_control.is_null() ^ ((*user_hdr).msg_controllen == 0) {
return_errno!(EINVAL, "control length is invalid");
}
if !(*user_hdr).msg_name.is_null() {
check_ptr((*user_hdr).msg_name as *const u8)?;
}
if !(*user_hdr).msg_iov.is_null() {
check_ptr((*user_hdr).msg_iov as *const u8)?;
let iov_slice = slice::from_raw_parts((*user_hdr).msg_iov, (*user_hdr).msg_iovlen);
for iov in iov_slice {
check_ptr(iov.iov_base as *const u8)?;
}
}
if !(*user_hdr).msg_control.is_null() {
check_ptr((*user_hdr).msg_control as *const u8)?;
}
Ok(())
}
*/
}
impl c_msghdr_ext for msghdr_mut {
fn check_member_ptrs(&self) -> Result<()> {
Ok(())
}
}

@ -31,3 +31,29 @@ pub fn align_down(addr: usize, align: usize) -> usize {
pub fn unbox<T>(value: Box<T>) -> T {
*value
}
pub trait SliceOptionExt<T> {
fn get_ptr_and_len(&self) -> (*const T, usize);
}
impl<T> SliceOptionExt<T> for Option<&[T]> {
fn get_ptr_and_len(&self) -> (*const T, usize) {
match self {
Some(self_slice) => (self_slice.as_ptr(), self_slice.len()),
None => (std::ptr::null(), 0),
}
}
}
pub trait MutSliceOptionExt<T> {
fn get_mut_ptr_and_len(&mut self) -> (*mut T, usize);
}
impl<T> MutSliceOptionExt<T> for Option<&mut [T]> {
fn get_mut_ptr_and_len(&mut self) -> (*mut T, usize) {
match self {
Some(self_slice) => (self_slice.as_mut_ptr(), self_slice.len()),
None => (std::ptr::null_mut(), 0),
}
}
}

@ -7,8 +7,12 @@
//! 3. Dispatch the syscall to `do_*` (at this file)
//! 4. Do some memory checks then call `mod::do_*` (at each module)
use fs::*;
use fs::{
AccessFlags, AccessModes, AsUnixSocket, FcntlCmd, File, FileDesc, FileRef, IoctlCmd,
UnixSocketFile, AT_FDCWD,
};
use misc::{resource_t, rlimit_t, utsname_t};
use net::{msghdr, msghdr_mut, AsSocket, SocketFile};
use process::{pid_t, ChildProcessFilter, CloneFlags, CpuSet, FileAction, FutexFlags, FutexOp};
use std::ffi::{CStr, CString};
use std::ptr;
@ -307,6 +311,7 @@ pub extern "C" fn dispatch_syscall(
arg4 as *mut libc::sockaddr,
arg5 as *mut libc::socklen_t,
),
SYS_SOCKETPAIR => do_socketpair(
arg0 as c_int,
arg1 as c_int,
@ -314,6 +319,9 @@ pub extern "C" fn dispatch_syscall(
arg3 as *mut c_int,
),
SYS_SENDMSG => net::do_sendmsg(arg0 as c_int, arg1 as *const msghdr, arg2 as c_int),
SYS_RECVMSG => net::do_recvmsg(arg0 as c_int, arg1 as *mut msghdr_mut, arg2 as c_int),
_ => do_unknown(num, arg0, arg1, arg2, arg3, arg4, arg5),
};

@ -15,6 +15,7 @@
#include <sys/syscall.h>
#include <sys/time.h>
#include <unistd.h>
#include <sys/socket.h>
#include <sgx_eid.h>
#include <sgx_error.h>
@ -296,6 +297,65 @@ void ocall_sync(void) {
sync();
}
ssize_t ocall_sendmsg(int sockfd,
const void *msg_name,
socklen_t msg_namelen,
const void *buf,
size_t buf_len,
const void *msg_control,
size_t msg_controllen,
int flags)
{
struct iovec msg_iov = { .iov_base = (void*)buf, .iov_len = buf_len };
struct iovec* p_msg_iov = buf != NULL ? &msg_iov : NULL;
size_t msg_iovlen = buf != NULL ? 1 : 0;
struct msghdr msg = {
(void*) msg_name,
msg_namelen,
p_msg_iov,
msg_iovlen,
(void*) msg_control,
msg_controllen,
0,
};
return sendmsg(sockfd, &msg, flags);
}
ssize_t ocall_recvmsg(int sockfd,
void *msg_name,
socklen_t msg_namelen,
socklen_t* msg_namelen_recv,
void *buf,
size_t buf_len,
void *msg_control,
size_t msg_controllen,
size_t* msg_controllen_recv,
int* msg_flags_recv,
int flags)
{
struct iovec msg_iov = { .iov_base = buf, .iov_len = buf_len };
struct iovec* p_msg_iov = buf != NULL ? &msg_iov : NULL;
size_t msg_iovlen = buf != NULL ? 1 : 0;
struct msghdr msg = {
msg_name,
msg_namelen,
p_msg_iov,
msg_iovlen,
msg_control,
msg_controllen,
0,
};
ssize_t ret = recvmsg(sockfd, &msg, flags);
if (ret < 0) return ret;
*msg_namelen_recv = msg.msg_namelen;
*msg_controllen_recv = msg.msg_controllen;
*msg_flags_recv = msg.msg_flags;
return ret;
}
// ==========================================================================
// Main
// ==========================================================================

@ -9,45 +9,138 @@
#include <spawn.h>
#include <unistd.h>
int main(int argc, const char *argv[]) {
const char* message = "Hello world!";
int ret;
#include "test.h"
if (argc != 3) {
printf("usage: ./client <ipaddress> <port>\n");
#define RESPONSE "ACK"
#define DEFAULT_MSG "Hello World!\n"
int connect_with_server(const char *addr_string, const char *port_string) {
//"NULL" addr means connectionless, no need to connect to server
if (strcmp(addr_string, "NULL") == 0)
return 0;
}
int ret = 0;
int sockfd = socket(AF_INET, SOCK_STREAM, 0);
if (sockfd < 0) {
printf("create socket error: %s(errno: %d)\n", strerror(errno), errno);
return -1;
}
if (sockfd < 0)
THROW_ERROR("create socket error");
struct sockaddr_in servaddr;
memset(&servaddr, 0, sizeof(servaddr));
servaddr.sin_family = AF_INET;
servaddr.sin_port = htons((uint16_t)strtol(argv[2], NULL, 10));
ret = inet_pton(AF_INET, argv[1], &servaddr.sin_addr);
servaddr.sin_port = htons((uint16_t)strtol(port_string, NULL, 10));
ret = inet_pton(AF_INET, addr_string, &servaddr.sin_addr);
if (ret <= 0) {
printf("inet_pton error for %s\n", argv[1]);
return -1;
close(sockfd);
THROW_ERROR("inet_pton error");
}
ret = connect(sockfd, (struct sockaddr *) &servaddr, sizeof(servaddr));
if (ret < 0) {
printf("connect error: %s(errno: %d)\n", strerror(errno), errno);
return -1;
}
printf("send msg to server: %s\n", message);
ret = send(sockfd, message, strlen(message), 0);
if (ret < 0) {
printf("send msg error: %s(errno: %d)\n", strerror(errno), errno);
return -1;
}
close(sockfd);
THROW_ERROR("connect error");
}
return sockfd;
}
int neogotiate_msg(int server_fd, char *buf, int buf_size) {
if (read(server_fd, buf, buf_size) < 0)
THROW_ERROR("read failed");
if (write(server_fd, RESPONSE, sizeof(RESPONSE)) < 0) {
THROW_ERROR("write failed");
}
return 0;
}
int client_send(int server_fd, char *buf) {
if (send(server_fd, buf, strlen(buf), 0) < 0)
THROW_ERROR("send msg error");
return 0;
}
int client_sendmsg(int server_fd, char *buf) {
int ret = 0;
struct msghdr msg;
struct iovec iov[1];
msg.msg_name = NULL;
msg.msg_namelen = 0;
iov[0].iov_base = buf;
iov[0].iov_len = strlen(buf);
msg.msg_iov = iov;
msg.msg_iovlen = 1;
msg.msg_control = 0;
msg.msg_controllen = 0;
msg.msg_flags = 0;
ret = sendmsg(server_fd, &msg, 0);
if (ret <= 0)
THROW_ERROR("sendmsg failed");
return ret;
}
int client_connectionless_sendmsg(char *buf) {
int ret = 0;
struct msghdr msg;
struct iovec iov[1];
struct sockaddr_in servaddr;
memset(&servaddr, 0, sizeof(servaddr));
servaddr.sin_family = AF_INET;
servaddr.sin_port = htons(9900);
servaddr.sin_addr.s_addr= htonl(INADDR_ANY);
msg.msg_name = &servaddr;
msg.msg_namelen = sizeof(servaddr);
iov[0].iov_base = buf;
iov[0].iov_len = strlen(buf);
msg.msg_iov = iov;
msg.msg_iovlen = 1;
msg.msg_control = 0;
msg.msg_controllen = 0;
msg.msg_flags = 0;
int server_fd = socket(AF_INET, SOCK_DGRAM, 0);
if (server_fd < 0)
THROW_ERROR("create socket error");
ret = sendmsg(server_fd, &msg, 0);
if (ret <= 0)
THROW_ERROR("sendmsg failed");
return ret;
}
int main(int argc, const char *argv[]) {
if (argc != 3) {
THROW_ERROR("usage: ./client <ipaddress> <port>\n");
}
int ret = 0;
const int buf_size = 100;
char buf[buf_size];
int port = strtol(argv[2], NULL, 10);
int server_fd = connect_with_server(argv[1], argv[2]);
switch (port)
{
case 8800:
neogotiate_msg(server_fd, buf, buf_size);
break;
case 8801:
neogotiate_msg(server_fd, buf, buf_size);
ret = client_send(server_fd, buf);
break;
case 8802:
neogotiate_msg(server_fd, buf, buf_size);
ret = client_sendmsg(server_fd, buf);
break;
case 8803:
ret = client_connectionless_sendmsg(DEFAULT_MSG);
break;
default:
ret = client_send(server_fd, DEFAULT_MSG);
}
close(server_fd);
return ret;
}

@ -4,62 +4,259 @@
#include <errno.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/wait.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <spawn.h>
#include <unistd.h>
int main(int argc, const char *argv[]) {
const int BUF_SIZE = 0x1000;
int ret;
#include "test.h"
int listenfd = socket(AF_INET, SOCK_STREAM, 0);
if (listenfd < 0) {
printf("create socket error: %s(errno: %d)\n", strerror(errno), errno);
return -1;
}
#define ECHO_MSG "msg for client/server test"
#define RESPONSE "ACK"
#define DEFAULT_MSG "Hello World!\n"
int connect_with_child(int port, int *child_pid) {
int ret = 0;
int listen_fd = socket(AF_INET, SOCK_STREAM, 0);
if (listen_fd < 0)
THROW_ERROR("create socket error");
int reuse = 1;
if (setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)) < 0)
THROW_ERROR("setsockopt port to reuse failed");
struct sockaddr_in servaddr;
memset(&servaddr, 0, sizeof(servaddr));
servaddr.sin_family = AF_INET;
servaddr.sin_addr.s_addr = htonl(INADDR_ANY);
servaddr.sin_port = htons(6666);
int reuse = 1;
if (setsockopt(listenfd, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)) < 0)
perror("setsockopt port to reuse failed");
ret = bind(listenfd, (struct sockaddr *) &servaddr, sizeof(servaddr));
servaddr.sin_port = htons(port);
ret = bind(listen_fd, (struct sockaddr *) &servaddr, sizeof(servaddr));
if (ret < 0) {
printf("bind socket error: %s(errno: %d)\n", strerror(errno), errno);
return -1;
close(listen_fd);
THROW_ERROR("bind socket failed");
}
ret = listen(listenfd, 10);
ret = listen(listen_fd, 10);
if (ret < 0) {
printf("listen socket error: %s(errno: %d)\n", strerror(errno), errno);
return -1;
close(listen_fd);
THROW_ERROR("listen socket error");
}
int client_pid;
char* client_argv[] = {"client", "127.0.0.1", "6666", NULL};
ret = posix_spawn(&client_pid, "/bin/client", NULL, NULL, client_argv, NULL);
char port_string[8];
sprintf(port_string, "%d", port);
char* client_argv[] = {"client", "127.0.0.1", port_string, NULL};
ret = posix_spawn(child_pid, "/bin/client", NULL, NULL, client_argv, NULL);
if (ret < 0) {
printf("spawn client process error: %s(errno: %d)\n", strerror(errno), errno);
return -1;
close(listen_fd);
THROW_ERROR("spawn client process error");
}
printf("====== waiting for client's request ======\n");
int connect_fd = accept(listenfd, (struct sockaddr *) NULL, NULL);
if (connect_fd < 0) {
printf("accept socket error: %s(errno: %d)", strerror(errno), errno);
return -1;
int connected_fd = accept(listen_fd, (struct sockaddr *) NULL, NULL);
if (connected_fd < 0) {
close(listen_fd);
THROW_ERROR("accept socket error");
}
char buff[BUF_SIZE];
int n = recv(connect_fd, buff, BUF_SIZE, 0);
buff[n] = '\0';
printf("recv msg from client: %s\n", buff);
close(connect_fd);
close(listenfd);
close(listen_fd);
return connected_fd;
}
int neogotiate_msg(int client_fd) {
char buf[16];
if (write(client_fd, ECHO_MSG, sizeof(ECHO_MSG)) < 0)
THROW_ERROR("write failed");
if (read(client_fd, buf, 16) < 0)
THROW_ERROR("read failed");
if (strncmp(buf, RESPONSE, sizeof(RESPONSE)) != 0) {
THROW_ERROR("msg recv mismatch");
}
return 0;
}
int server_recv(int client_fd) {
const int buf_size = 32;
char buf[buf_size];
if (recv(client_fd, buf, buf_size, 0) <= 0)
THROW_ERROR("msg recv failed");
if (strncmp(buf, ECHO_MSG, sizeof(ECHO_MSG)) != 0) {
THROW_ERROR("msg recv mismatch");
}
return 0;
}
int server_recvmsg(int client_fd) {
int ret = 0;
const int buf_size = 1000;
char buf[buf_size];
struct msghdr msg;
struct iovec iov[1];
msg.msg_name = NULL;
msg.msg_namelen = 0;
iov[0].iov_base = buf;
iov[0].iov_len = buf_size;
msg.msg_iov = iov;
msg.msg_iovlen = 1;
msg.msg_control = 0;
msg.msg_controllen = 0;
msg.msg_flags = 0;
ret = recvmsg(client_fd, &msg, 0);
if (ret <= 0) {
THROW_ERROR("recvmsg failed");
} else {
if (strncmp(buf, ECHO_MSG, sizeof(ECHO_MSG)) != 0) {
printf("recvmsg : %d, msg: %s\n", ret, buf);
THROW_ERROR("msg recvmsg mismatch");
}
}
return ret;
}
int server_connectionless_recvmsg() {
int ret = 0;
const int buf_size = 1000;
char buf[buf_size];
struct msghdr msg;
struct iovec iov[1];
struct sockaddr_in servaddr;
struct sockaddr_in clientaddr;
memset(&servaddr, 0, sizeof(servaddr));
memset(&clientaddr, 0, sizeof(clientaddr));
int sock = socket(AF_INET, SOCK_DGRAM, 0);
if (sock < 0)
THROW_ERROR("create socket error");
servaddr.sin_family = AF_INET;
servaddr.sin_addr.s_addr = htonl(INADDR_ANY);
servaddr.sin_port = htons(9900);
ret = bind(sock, (struct sockaddr *) &servaddr, sizeof(servaddr));
if (ret < 0) {
close(sock);
THROW_ERROR("bind socket failed");
}
msg.msg_name = &clientaddr;
msg.msg_namelen = sizeof(clientaddr);
iov[0].iov_base = buf;
iov[0].iov_len = buf_size;
msg.msg_iov = iov;
msg.msg_iovlen = 1;
msg.msg_control = 0;
msg.msg_controllen = 0;
msg.msg_flags = 0;
ret = recvmsg(sock, &msg, 0);
if (ret <= 0) {
THROW_ERROR("recvmsg failed");
} else {
if (strncmp(buf, DEFAULT_MSG, sizeof(DEFAULT_MSG)) != 0) {
printf("recvmsg : %d, msg: %s\n", ret, buf);
THROW_ERROR("msg recvmsg mismatch");
} else {
inet_ntop(AF_INET, &clientaddr.sin_addr,
buf, sizeof(buf));
if(strcmp(buf, "127.0.0.1") !=0) {
printf("from port %d and address %s\n", ntohs(clientaddr.sin_port), buf);
THROW_ERROR("client addr mismatch");
}
}
}
return ret;
}
int wait_for_child_exit(int child_pid) {
int status = 0;
if (wait4(child_pid, &status, 0, NULL) < 0) {
THROW_ERROR("failed to wait4 the child process");
}
return 0;
}
int test_read_write() {
int ret = 0;
int child_pid = 0;
int client_fd = connect_with_child(8800, &child_pid);
if (client_fd < 0)
THROW_ERROR("connect failed");
else
ret = neogotiate_msg(client_fd);
//wait for the child to exit for next spawn
int status = 0;
if (wait4(child_pid, &status, 0, NULL) < 0) {
THROW_ERROR("failed to wait4 the child process");
}
return ret;
}
int test_send_recv() {
int ret = 0;
int child_pid = 0;
int client_fd = connect_with_child(8801, &child_pid);
if (client_fd < 0)
THROW_ERROR("connect failed");
if (neogotiate_msg(client_fd) < 0)
THROW_ERROR("neogotiate failed");
ret = server_recv(client_fd);
if (ret < 0) return -1;
ret = wait_for_child_exit(child_pid);
return ret;
}
int test_sendmsg_recvmsg() {
int ret = 0;
int child_pid = 0;
int client_fd = connect_with_child(8802, &child_pid);
if (client_fd < 0)
THROW_ERROR("connect failed");
if (neogotiate_msg(client_fd) < 0)
THROW_ERROR("neogotiate failed");
ret = server_recvmsg(client_fd);
if (ret < 0) return -1;
ret = wait_for_child_exit(child_pid);
return ret;
}
int test_sendmsg_recvmsg_connectionless() {
int ret = 0;
int child_pid = 0;
char* client_argv[] = {"client", "NULL", "8803", NULL};
ret = posix_spawn(&child_pid, "/bin/client", NULL, NULL, client_argv, NULL);
if (ret < 0) {
THROW_ERROR("spawn client process error");
}
ret = server_connectionless_recvmsg();
if (ret < 0) return -1;
ret = wait_for_child_exit(child_pid);
return ret;
}
static test_case_t test_cases[] = {
TEST_CASE(test_read_write),
TEST_CASE(test_send_recv),
TEST_CASE(test_sendmsg_recvmsg),
TEST_CASE(test_sendmsg_recvmsg_connectionless),
};
int main(int argc, const char* argv[]) {
return test_suite_run(test_cases, ARRAY_SIZE(test_cases));
}