diff --git a/src/Enclave.edl b/src/Enclave.edl index 2192fa8a..c1dae23b 100644 --- a/src/Enclave.edl +++ b/src/Enclave.edl @@ -40,5 +40,30 @@ enclave { [out] sgx_report_t* qe_report, [out, size=quote_buf_len] uint8_t* quote_buf, uint32_t quote_buf_len); - }; + + int64_t ocall_sendmsg( + int sockfd, + [in, size=msg_namelen] const void* msg_name, + socklen_t msg_namelen, + [in, size=buf_len] const void* buf, + size_t buf_len, + [in, size=msg_controllen] const void* msg_control, + size_t msg_controllen, + int flags + ) propagate_errno; + + int64_t ocall_recvmsg( + int sockfd, + [out, size=msg_namelen] void *msg_name, + socklen_t msg_namelen, + [out] socklen_t* msg_namelen_recv, + [out, size=buf_len] void* buf, + size_t buf_len, + [out, size=msg_controllen] void *msg_control, + size_t msg_controllen, + [out] size_t* msg_controllen_recv, + [out] int* msg_flags_recv, + int flags + ) propagate_errno; + }; }; diff --git a/src/libos/src/error/errno.rs b/src/libos/src/error/errno.rs index ffcde387..7308fc27 100644 --- a/src/libos/src/error/errno.rs +++ b/src/libos/src/error/errno.rs @@ -141,6 +141,7 @@ pub enum Errno { EHWPOISON = 133, // Note: always keep the last item in sync with ERRNO_MAX } +const ERRNO_MIN: u32 = Errno::EPERM as u32; const ERRNO_MAX: u32 = Errno::EHWPOISON as u32; impl Errno { @@ -192,10 +193,8 @@ impl Errno { } impl From for Errno { - fn from(mut raw_errno: u32) -> Self { - if raw_errno > ERRNO_MAX { - raw_errno = 0; - } + fn from(raw_errno: u32) -> Self { + assert!(ERRNO_MIN <= raw_errno && raw_errno <= ERRNO_MAX); unsafe { core::mem::transmute(raw_errno as u8) } } } diff --git a/src/libos/src/error/mod.rs b/src/libos/src/error/mod.rs index 7fa192dd..c46c5903 100644 --- a/src/libos/src/error/mod.rs +++ b/src/libos/src/error/mod.rs @@ -44,7 +44,7 @@ macro_rules! return_errno { macro_rules! try_libc { ($ret: expr) => {{ let ret = unsafe { $ret }; - if ret == -1 { + if ret < 0 { let errno = unsafe { libc::errno() }; return_errno!(Errno::from(errno as u32), "libc error"); } diff --git a/src/libos/src/fs/mod.rs b/src/libos/src/fs/mod.rs index 76c17fb4..4b83df09 100644 --- a/src/libos/src/fs/mod.rs +++ b/src/libos/src/fs/mod.rs @@ -1,5 +1,6 @@ use super::*; +use net::{AsSocket, SocketFile}; use process::Process; use rcore_fs::vfs::{FileType, FsError, INode, Metadata, Timespec}; use std::io::{Read, Seek, SeekFrom, Write}; @@ -19,7 +20,6 @@ pub use self::io_multiplexing::*; pub use self::ioctl::*; pub use self::pipe::Pipe; pub use self::root_inode::ROOT_INODE; -pub use self::socket_file::{AsSocket, SocketFile}; pub use self::unix_socket::{AsUnixSocket, UnixSocketFile}; use sgx_trts::libc::S_IWUSR; use std::any::Any; @@ -39,7 +39,6 @@ mod ioctl; mod pipe; mod root_inode; mod sgx_impl; -mod socket_file; mod unix_socket; pub fn do_open(path: &str, flags: u32, mode: u32) -> Result { diff --git a/src/libos/src/lib.rs b/src/libos/src/lib.rs index a3c2ac6a..c2a953be 100644 --- a/src/libos/src/lib.rs +++ b/src/libos/src/lib.rs @@ -57,6 +57,7 @@ mod entry; mod exception; mod fs; mod misc; +mod net; mod process; mod syscall; mod time; diff --git a/src/libos/src/net/iovs.rs b/src/libos/src/net/iovs.rs new file mode 100644 index 00000000..42b60d08 --- /dev/null +++ b/src/libos/src/net/iovs.rs @@ -0,0 +1,88 @@ +//! I/O vectors + +use super::*; + +/// A memory safe, immutable version of C iovec array +pub struct Iovs<'a> { + iovs: Vec<&'a [u8]>, +} + +impl<'a> Iovs<'a> { + pub fn new(slices: Vec<&'a [u8]>) -> Iovs { + Self { iovs: slices } + } + + pub fn as_slices(&self) -> &[&[u8]] { + &self.iovs[..] + } + + pub fn total_bytes(&self) -> usize { + self.iovs.iter().map(|s| s.len()).sum() + } + + pub fn gather_to_vec(&self) -> Vec { + Self::gather_slices_to_vec(&self.iovs[..]) + } + + fn gather_slices_to_vec(slices: &[&[u8]]) -> Vec { + let vec_len = slices.iter().map(|slice| slice.len()).sum(); + let mut vec = Vec::with_capacity(vec_len); + for slice in slices { + vec.extend_from_slice(slice); + } + vec + } +} + +/// A memory safe, mutable version of C iovec array +pub struct IovsMut<'a> { + iovs: Vec<&'a mut [u8]>, +} + +impl<'a> IovsMut<'a> { + pub fn new(slices: Vec<&'a mut [u8]>) -> Self { + Self { iovs: slices } + } + + pub fn as_slices<'b>(&'b self) -> &'b [&'a [u8]] { + let slices_mut: &'b [&'a mut [u8]] = &self.iovs[..]; + // We are "downgrading" _mutable_ slices to _immutable_ ones. It should be + // safe to do this transmute + unsafe { std::mem::transmute(slices_mut) } + } + + pub fn as_slices_mut<'b>(&'b mut self) -> &'b mut [&'a mut [u8]] { + &mut self.iovs[..] + } + + pub fn total_bytes(&self) -> usize { + self.iovs.iter().map(|s| s.len()).sum() + } + + pub fn gather_to_vec(&self) -> Vec { + Iovs::gather_slices_to_vec(self.as_slices()) + } + + pub fn scatter_copy_from(&mut self, data: &[u8]) -> usize { + let mut total_nbytes = 0; + let mut remain_slice = data; + for iov in &mut self.iovs { + if remain_slice.len() == 0 { + break; + } + + let copy_nbytes = remain_slice.len().min(iov.len()); + let dst_slice = unsafe { + debug_assert!(iov.len() >= copy_nbytes); + iov.get_unchecked_mut(..copy_nbytes) + }; + let (src_slice, _remain_slice) = remain_slice.split_at(copy_nbytes); + dst_slice.copy_from_slice(src_slice); + + remain_slice = _remain_slice; + total_nbytes += copy_nbytes; + } + debug_assert!(remain_slice.len() == 0); + total_nbytes + } +} diff --git a/src/libos/src/net/mod.rs b/src/libos/src/net/mod.rs new file mode 100644 index 00000000..e1283731 --- /dev/null +++ b/src/libos/src/net/mod.rs @@ -0,0 +1,13 @@ +use super::*; + +mod iovs; +mod msg; +mod msg_flags; +mod socket_file; +mod syscalls; + +pub use self::iovs::{Iovs, IovsMut}; +pub use self::msg::{msghdr, msghdr_mut, MsgHdr, MsgHdrMut}; +pub use self::msg_flags::MsgFlags; +pub use self::socket_file::{AsSocket, SocketFile}; +pub use self::syscalls::{do_recvmsg, do_sendmsg}; diff --git a/src/libos/src/net/msg.rs b/src/libos/src/net/msg.rs new file mode 100644 index 00000000..9007f69a --- /dev/null +++ b/src/libos/src/net/msg.rs @@ -0,0 +1,231 @@ +/// Socket message and its flags. +use super::*; + +/// C struct for a socket message with const pointers +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct msghdr { + pub msg_name: *const c_void, + pub msg_namelen: libc::socklen_t, + pub msg_iov: *const libc::iovec, + pub msg_iovlen: size_t, + pub msg_control: *const c_void, + pub msg_controllen: size_t, + pub msg_flags: c_int, +} + +/// C struct for a socket message with mutable pointers +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct msghdr_mut { + pub msg_name: *mut c_void, + pub msg_namelen: libc::socklen_t, + pub msg_iov: *mut libc::iovec, + pub msg_iovlen: size_t, + pub msg_control: *mut c_void, + pub msg_controllen: size_t, + pub msg_flags: c_int, +} + +/// MsgHdr is a memory-safe, immutable wrapper of msghdr +pub struct MsgHdr<'a> { + name: Option<&'a [u8]>, + iovs: Iovs<'a>, + control: Option<&'a [u8]>, + flags: MsgFlags, + c_self: &'a msghdr, +} + +impl<'a> MsgHdr<'a> { + /// Wrap a unsafe msghdr into a safe MsgHdr + pub unsafe fn from_c(c_msg: &'a msghdr) -> Result { + // Convert c_msg's (*mut T, usize)-pair fields to Option<&mut [T]> + let name_opt_slice = + new_optional_slice(c_msg.msg_name as *const u8, c_msg.msg_namelen as usize); + let iovs_opt_slice = new_optional_slice( + c_msg.msg_iov as *const libc::iovec, + c_msg.msg_iovlen as usize, + ); + let control_opt_slice = new_optional_slice( + c_msg.msg_control as *const u8, + c_msg.msg_controllen as usize, + ); + + let flags = MsgFlags::from_u32(c_msg.msg_flags as u32)?; + + let iovs = { + let iovs_vec = match iovs_opt_slice { + Some(iovs_slice) => iovs_slice + .iter() + .flat_map(|iov| new_optional_slice(iov.iov_base as *const u8, iov.iov_len)) + .collect(), + None => Vec::new(), + }; + Iovs::new(iovs_vec) + }; + + Ok(Self { + name: name_opt_slice, + iovs: iovs, + control: control_opt_slice, + flags: flags, + c_self: c_msg, + }) + } + + pub fn get_iovs(&self) -> &Iovs { + &self.iovs + } + + pub fn get_name(&self) -> Option<&[u8]> { + self.name + } + + pub fn get_control(&self) -> Option<&[u8]> { + self.control + } + + pub fn get_flags(&self) -> MsgFlags { + self.flags + } +} + +/// MsgHdrMut is a memory-safe, mutable wrapper of msghdr_mut +pub struct MsgHdrMut<'a> { + name: Option<&'a mut [u8]>, + iovs: IovsMut<'a>, + control: Option<&'a mut [u8]>, + flags: MsgFlags, + c_self: &'a mut msghdr_mut, +} + +// TODO: use macros to eliminate redundant code between MsgHdr and MsgHdrMut +impl<'a> MsgHdrMut<'a> { + /// Wrap a unsafe msghdr_mut into a safe MsgHdrMut + pub unsafe fn from_c(c_msg: &'a mut msghdr_mut) -> Result { + // Convert c_msg's (*mut T, usize)-pair fields to Option<&mut [T]> + let name_opt_slice = + new_optional_slice_mut(c_msg.msg_name as *mut u8, c_msg.msg_namelen as usize); + let iovs_opt_slice = + new_optional_slice_mut(c_msg.msg_iov as *mut libc::iovec, c_msg.msg_iovlen as usize); + let control_opt_slice = + new_optional_slice_mut(c_msg.msg_control as *mut u8, c_msg.msg_controllen as usize); + + let flags = MsgFlags::from_u32(c_msg.msg_flags as u32)?; + + let iovs = { + let iovs_vec = match iovs_opt_slice { + Some(iovs_slice) => iovs_slice + .iter() + .flat_map(|iov| new_optional_slice_mut(iov.iov_base as *mut u8, iov.iov_len)) + .collect(), + None => Vec::new(), + }; + IovsMut::new(iovs_vec) + }; + + Ok(Self { + name: name_opt_slice, + iovs: iovs, + control: control_opt_slice, + flags: flags, + c_self: c_msg, + }) + } + + ///////////////////////////////////////////////////////////////////////// + // Immutable interfaces (same as MsgHdr) + ///////////////////////////////////////////////////////////////////////// + + pub fn get_iovs(&self) -> &IovsMut { + &self.iovs + } + + pub fn get_name(&self) -> Option<&[u8]> { + self.name.as_ref().map(|name| &name[..]) + } + + pub fn get_control(&self) -> Option<&[u8]> { + self.control.as_ref().map(|control| &control[..]) + } + + pub fn get_flags(&self) -> MsgFlags { + self.flags + } + + ///////////////////////////////////////////////////////////////////////// + // Mutable interfaces (unique to MsgHdrMut) + ///////////////////////////////////////////////////////////////////////// + + pub fn get_iovs_mut<'b>(&'b mut self) -> &'b mut IovsMut<'a> { + &mut self.iovs + } + + pub fn get_name_mut(&mut self) -> Option<&mut [u8]> { + self.name.as_mut().map(|name| &mut name[..]) + } + + pub fn get_name_max_len(&self) -> usize { + self.name.as_ref().map(|name| name.len()).unwrap_or(0) + } + + pub fn set_name_len(&mut self, new_name_len: usize) -> Result<()> { + if new_name_len > self.get_name_max_len() { + return_errno!(EINVAL, "new_name_len is too big"); + } + self.c_self.msg_namelen = new_name_len as libc::socklen_t; + Ok(()) + } + + pub fn get_control_mut(&mut self) -> Option<&mut [u8]> { + self.control.as_mut().map(|control| &mut control[..]) + } + + pub fn get_control_max_len(&self) -> usize { + self.control + .as_ref() + .map(|control| control.len()) + .unwrap_or(0) + } + + pub fn set_control_len(&mut self, new_control_len: usize) -> Result<()> { + if new_control_len > self.get_control_max_len() { + return_errno!(EINVAL, "new_control_len is too big"); + } + self.c_self.msg_controllen = new_control_len; + Ok(()) + } + + pub fn get_name_and_control_mut(&mut self) -> (Option<&mut [u8]>, Option<&mut [u8]>) { + ( + self.name.as_mut().map(|name| &mut name[..]), + self.control.as_mut().map(|control| &mut control[..]), + ) + } + + pub fn set_flags(&mut self, flags: MsgFlags) { + self.flags = flags; + self.c_self.msg_flags = flags.to_u32() as i32; + } +} + +unsafe fn new_optional_slice<'a, T>(slice_ptr: *const T, slice_size: usize) -> Option<&'a [T]> { + if !slice_ptr.is_null() { + let slice = core::slice::from_raw_parts::(slice_ptr, slice_size); + Some(slice) + } else { + None + } +} + +unsafe fn new_optional_slice_mut<'a, T>( + slice_ptr: *mut T, + slice_size: usize, +) -> Option<&'a mut [T]> { + if !slice_ptr.is_null() { + let slice = core::slice::from_raw_parts_mut::(slice_ptr, slice_size); + Some(slice) + } else { + None + } +} diff --git a/src/libos/src/net/msg_flags.rs b/src/libos/src/net/msg_flags.rs new file mode 100644 index 00000000..096e00ec --- /dev/null +++ b/src/libos/src/net/msg_flags.rs @@ -0,0 +1,17 @@ +use super::*; + +// TODO: use bitflag! to make this memory safe +#[derive(Debug, Copy, Clone, Default)] +pub struct MsgFlags { + bits: u32, +} + +impl MsgFlags { + pub fn from_u32(c_flags: u32) -> Result { + Ok(MsgFlags { bits: 0 }) + } + + pub fn to_u32(&self) -> u32 { + self.bits + } +} diff --git a/src/libos/src/fs/socket_file.rs b/src/libos/src/net/socket_file/mod.rs similarity index 74% rename from src/libos/src/fs/socket_file.rs rename to src/libos/src/net/socket_file/mod.rs index a4c90a5e..dc89a3bc 100644 --- a/src/libos/src/fs/socket_file.rs +++ b/src/libos/src/net/socket_file/mod.rs @@ -1,16 +1,22 @@ use super::*; + +mod recv; +mod send; + +use fs::{File, FileRef, IoctlCmd}; use std::any::Any; +use std::io::{Read, Seek, SeekFrom, Write}; /// Native Linux socket #[derive(Debug)] pub struct SocketFile { - fd: c_int, + host_fd: c_int, } impl SocketFile { pub fn new(domain: c_int, socket_type: c_int, protocol: c_int) -> Result { let ret = try_libc!(libc::ocall::socket(domain, socket_type, protocol)); - Ok(SocketFile { fd: ret }) + Ok(SocketFile { host_fd: ret }) } pub fn accept( @@ -19,32 +25,28 @@ impl SocketFile { addr_len: *mut libc::socklen_t, flags: c_int, ) -> Result { - let ret = try_libc!(libc::ocall::accept4(self.fd, addr, addr_len, flags)); - Ok(SocketFile { fd: ret }) + let ret = try_libc!(libc::ocall::accept4(self.host_fd, addr, addr_len, flags)); + Ok(SocketFile { host_fd: ret }) } pub fn fd(&self) -> c_int { - self.fd + self.host_fd } } impl Drop for SocketFile { fn drop(&mut self) { - let ret = unsafe { libc::ocall::close(self.fd) }; - if ret < 0 { - let errno = unsafe { libc::errno() }; - warn!( - "socket (host fd: {}) close failed: errno = {}", - self.fd, errno - ); - } + let ret = unsafe { libc::ocall::close(self.host_fd) }; + assert!(ret == 0); } } +// TODO: rewrite read/write/readv/writev as send/recv +// TODO: implement readfrom/sendto impl File for SocketFile { fn read(&self, buf: &mut [u8]) -> Result { let ret = try_libc!(libc::ocall::read( - self.fd, + self.host_fd, buf.as_mut_ptr() as *mut c_void, buf.len() )); @@ -53,7 +55,7 @@ impl File for SocketFile { fn write(&self, buf: &[u8]) -> Result { let ret = try_libc!(libc::ocall::write( - self.fd, + self.host_fd, buf.as_ptr() as *const c_void, buf.len() )); @@ -100,25 +102,6 @@ impl File for SocketFile { return_errno!(ESPIPE, "Socket does not support seek") } - fn metadata(&self) -> Result { - Ok(Metadata { - dev: 0, - inode: 0, - size: 0, - blk_size: 0, - blocks: 0, - atime: Timespec { sec: 0, nsec: 0 }, - mtime: Timespec { sec: 0, nsec: 0 }, - ctime: Timespec { sec: 0, nsec: 0 }, - type_: FileType::Socket, - mode: 0, - nlinks: 0, - uid: 0, - gid: 0, - rdev: 0, - }) - } - fn ioctl(&self, cmd: &mut IoctlCmd) -> Result<()> { let cmd_num = cmd.cmd_num() as c_int; let cmd_arg_ptr = cmd.arg_ptr() as *const c_int; diff --git a/src/libos/src/net/socket_file/recv.rs b/src/libos/src/net/socket_file/recv.rs new file mode 100644 index 00000000..23d7e6b9 --- /dev/null +++ b/src/libos/src/net/socket_file/recv.rs @@ -0,0 +1,138 @@ +use super::*; + +impl SocketFile { + // TODO: need sockaddr type to implement send/sento + /* + pub fn recv(&self, buf: &mut [u8], flags: MsgFlags) -> Result { + let (bytes_recvd, _) = self.recvfrom(buf, flags, None)?; + Ok(bytes_recvd) + } + + pub fn recvfrom(&self, buf: &mut [u8], flags: MsgFlags, src_addr: Option<&mut [u8]>) -> Result<(usize, usize)> { + let (bytes_recvd, src_addr_len, _, _) = self.do_recvmsg( + &mut buf[..], + flags, + src_addr, + None, + )?; + Ok((bytes_recvd, src_addr_len)) + }*/ + + pub fn recvmsg<'a, 'b>(&self, msg: &'b mut MsgHdrMut<'a>, flags: MsgFlags) -> Result { + // Allocate a single data buffer is big enough for all iovecs of msg. + // This is a workaround for the OCall that takes only one data buffer. + let mut data_buf = { + let data_buf_len = msg.get_iovs().total_bytes(); + let data_vec = vec![0; data_buf_len]; + data_vec.into_boxed_slice() + }; + + let (bytes_recvd, namelen_recvd, controllen_recvd, flags_recvd) = { + let data = &mut data_buf[..]; + // Acquire mutable references to the name and control buffers + let (name, control) = msg.get_name_and_control_mut(); + // Fill the data, the name, and the control buffers + self.do_recvmsg(data, flags, name, control)? + }; + + // Update the lengths and flags + msg.set_name_len(namelen_recvd)?; + msg.set_control_len(controllen_recvd)?; + msg.set_flags(flags_recvd); + + let recv_data = &data_buf[..bytes_recvd]; + // TODO: avoid this one extra copy due to the intermediate data buffer + msg.get_iovs_mut().scatter_copy_from(recv_data); + + Ok(bytes_recvd) + } + + fn do_recvmsg( + &self, + data: &mut [u8], + flags: MsgFlags, + mut name: Option<&mut [u8]>, + mut control: Option<&mut [u8]>, + ) -> Result<(usize, usize, usize, MsgFlags)> { + // Prepare the arguments for OCall + // Host socket fd + let host_fd = self.host_fd; + // Name + let (msg_name, msg_namelen) = name.get_mut_ptr_and_len(); + let msg_name = msg_name as *mut c_void; + let mut msg_namelen_recvd = 0_u32; + // Data + let msg_data = data.as_mut_ptr(); + let msg_datalen = data.len(); + // Control + let (msg_control, msg_controllen) = control.get_mut_ptr_and_len(); + let msg_control = msg_control as *mut c_void; + let mut msg_controllen_recvd = 0; + // Flags + let flags = flags.to_u32() as i32; + let mut msg_flags_recvd = 0; + + // Do OCall + let retval = try_libc!({ + let mut retval = 0_isize; + let status = ocall_recvmsg( + &mut retval as *mut isize, + host_fd, + msg_name, + msg_namelen as u32, + &mut msg_namelen_recvd as *mut u32, + msg_data, + msg_datalen, + msg_control, + msg_controllen, + &mut msg_controllen_recvd as *mut usize, + &mut msg_flags_recvd as *mut i32, + flags, + ); + assert!(status == sgx_status_t::SGX_SUCCESS); + + // TODO: what if retval < 0 but buffers are modified by the + // untrusted OCall? We reset the potentially tampered buffers. + retval + }); + + // Check values returned from outside the enclave + let bytes_recvd = { + // Guarantted by try_libc! + debug_assert!(retval >= 0); + let retval = retval as usize; + + // Check bytes_recvd returned from outside the enclave + assert!(retval <= data.len()); + retval + }; + let msg_namelen_recvd = msg_namelen_recvd as usize; + assert!(msg_namelen_recvd <= msg_namelen); + assert!(msg_controllen_recvd <= msg_controllen); + let flags_recvd = MsgFlags::from_u32(msg_flags_recvd as u32)?; + + Ok(( + bytes_recvd, + msg_namelen_recvd, + msg_controllen_recvd, + flags_recvd, + )) + } +} + +extern "C" { + fn ocall_recvmsg( + ret: *mut ssize_t, + fd: c_int, + msg_name: *mut c_void, + msg_namelen: libc::socklen_t, + msg_namelen_recv: *mut libc::socklen_t, + msg_data: *mut u8, + msg_data: size_t, + msg_control: *mut c_void, + msg_controllen: size_t, + msg_controllen_recv: *mut size_t, + msg_flags: *mut c_int, + flags: c_int, + ) -> sgx_status_t; +} diff --git a/src/libos/src/net/socket_file/send.rs b/src/libos/src/net/socket_file/send.rs new file mode 100644 index 00000000..a47c4a05 --- /dev/null +++ b/src/libos/src/net/socket_file/send.rs @@ -0,0 +1,84 @@ +use super::*; + +impl SocketFile { + // TODO: need sockaddr type to implement send/sento + /* + pub fn send(&self, buf: &[u8], flags: MsgFlags) -> Result { + self.sendto(buf, flags, None) + } + + pub fn sendto(&self, buf: &[u8], flags: MsgFlags, dest_addr: Option<&[u8]>) -> Result { + Self::do_sendmsg( + self.host_fd, + &buf[..], + flags, + dest_addr, + None) + } + */ + + pub fn sendmsg<'a, 'b>(&self, msg: &'b MsgHdr<'a>, flags: MsgFlags) -> Result { + // Copy data in iovs into a single buffer + let data_buf = msg.get_iovs().gather_to_vec(); + + self.do_sendmsg(&data_buf[..], flags, msg.get_name(), msg.get_control()) + } + + fn do_sendmsg( + &self, + data: &[u8], + flags: MsgFlags, + name: Option<&[u8]>, + control: Option<&[u8]>, + ) -> Result { + let bytes_sent = try_libc!({ + // Prepare the arguments for OCall + let mut retval: isize = 0; + // Host socket fd + let host_fd = self.host_fd; + // Name + let (msg_name, msg_namelen) = name.get_ptr_and_len(); + let msg_name = msg_name as *const c_void; + // Data + let msg_data = data.as_ptr(); + let msg_datalen = data.len(); + // Control + let (msg_control, msg_controllen) = control.get_ptr_and_len(); + let msg_control = msg_control as *const c_void; + // Flags + let flags = flags.to_u32() as i32; + + // Do OCall + let status = ocall_sendmsg( + &mut retval as *mut isize, + host_fd, + msg_name, + msg_namelen as u32, + msg_data, + msg_datalen, + msg_control, + msg_controllen, + flags, + ); + assert!(status == sgx_status_t::SGX_SUCCESS); + + retval + }); + debug_assert!(bytes_sent >= 0); + Ok(bytes_sent as usize) + } +} + +extern "C" { + fn ocall_sendmsg( + ret: *mut ssize_t, + fd: c_int, + msg_name: *const c_void, + msg_namelen: libc::socklen_t, + msg_data: *const u8, + msg_datalen: size_t, + msg_control: *const c_void, + msg_controllen: size_t, + flags: c_int, + ) -> sgx_status_t; +} diff --git a/src/libos/src/net/syscalls.rs b/src/libos/src/net/syscalls.rs new file mode 100644 index 00000000..f2255afc --- /dev/null +++ b/src/libos/src/net/syscalls.rs @@ -0,0 +1,126 @@ +use super::*; + +use fs::{AsUnixSocket, File, FileDesc, FileRef, UnixSocketFile}; +use process::Process; +use util::mem_util::from_user; + +pub fn do_sendmsg(fd: c_int, msg_ptr: *const msghdr, flags_c: c_int) -> Result { + info!( + "sendmsg: fd: {}, msg: {:?}, flags: 0x{:x}", + fd, msg_ptr, flags_c + ); + let current_ref = process::get_current(); + let mut proc = current_ref.lock().unwrap(); + let file_ref = proc.get_files().lock().unwrap().get(fd as FileDesc)?; + + if let Ok(socket) = file_ref.as_socket() { + let msg_c = { + from_user::check_ptr(msg_ptr)?; + let msg_c = unsafe { &*msg_ptr }; + msg_c.check_member_ptrs()?; + msg_c + }; + let msg = unsafe { MsgHdr::from_c(&msg_c)? }; + + let flags = MsgFlags::from_u32(flags_c as u32)?; + + socket + .sendmsg(&msg, flags) + .map(|bytes_sent| bytes_sent as isize) + } else if let Ok(socket) = file_ref.as_unix_socket() { + return_errno!(EBADF, "does not support unix socket") + } else { + return_errno!(EBADF, "not a socket") + } +} + +pub fn do_recvmsg(fd: c_int, msg_mut_ptr: *mut msghdr_mut, flags_c: c_int) -> Result { + info!( + "recvmsg: fd: {}, msg: {:?}, flags: 0x{:x}", + fd, msg_mut_ptr, flags_c + ); + let current_ref = process::get_current(); + let mut proc = current_ref.lock().unwrap(); + let file_ref = proc.get_files().lock().unwrap().get(fd as FileDesc)?; + + if let Ok(socket) = file_ref.as_socket() { + let msg_mut_c = { + from_user::check_mut_ptr(msg_mut_ptr)?; + let msg_mut_c = unsafe { &mut *msg_mut_ptr }; + msg_mut_c.check_member_ptrs()?; + msg_mut_c + }; + let mut msg_mut = unsafe { MsgHdrMut::from_c(msg_mut_c)? }; + + let flags = MsgFlags::from_u32(flags_c as u32)?; + + socket + .recvmsg(&mut msg_mut, flags) + .map(|bytes_recvd| bytes_recvd as isize) + } else if let Ok(socket) = file_ref.as_unix_socket() { + return_errno!(EBADF, "does not support unix socket") + } else { + return_errno!(EBADF, "not a socket") + } +} + +#[allow(non_camel_case_types)] +trait c_msghdr_ext { + fn check_member_ptrs(&self) -> Result<()>; +} + +impl c_msghdr_ext for msghdr { + // TODO: implement this! + fn check_member_ptrs(&self) -> Result<()> { + Ok(()) + } + /* + ///user space check + pub unsafe fn check_from_user(user_hdr: *const msghdr) -> Result<()> { + Self::check_pointer(user_hdr, from_user::check_ptr) + } + + ///Check msghdr ptr + pub unsafe fn check_pointer( + user_hdr: *const msghdr, + check_ptr: fn(*const u8) -> Result<()>, + ) -> Result<()> { + check_ptr(user_hdr as *const u8)?; + + if (*user_hdr).msg_name.is_null() ^ ((*user_hdr).msg_namelen == 0) { + return_errno!(EINVAL, "name length is invalid"); + } + + if (*user_hdr).msg_iov.is_null() ^ ((*user_hdr).msg_iovlen == 0) { + return_errno!(EINVAL, "iov length is invalid"); + } + + if (*user_hdr).msg_control.is_null() ^ ((*user_hdr).msg_controllen == 0) { + return_errno!(EINVAL, "control length is invalid"); + } + + if !(*user_hdr).msg_name.is_null() { + check_ptr((*user_hdr).msg_name as *const u8)?; + } + + if !(*user_hdr).msg_iov.is_null() { + check_ptr((*user_hdr).msg_iov as *const u8)?; + let iov_slice = slice::from_raw_parts((*user_hdr).msg_iov, (*user_hdr).msg_iovlen); + for iov in iov_slice { + check_ptr(iov.iov_base as *const u8)?; + } + } + + if !(*user_hdr).msg_control.is_null() { + check_ptr((*user_hdr).msg_control as *const u8)?; + } + Ok(()) + } + */ +} + +impl c_msghdr_ext for msghdr_mut { + fn check_member_ptrs(&self) -> Result<()> { + Ok(()) + } +} diff --git a/src/libos/src/prelude.rs b/src/libos/src/prelude.rs index d3658527..9205ff78 100644 --- a/src/libos/src/prelude.rs +++ b/src/libos/src/prelude.rs @@ -31,3 +31,29 @@ pub fn align_down(addr: usize, align: usize) -> usize { pub fn unbox(value: Box) -> T { *value } + +pub trait SliceOptionExt { + fn get_ptr_and_len(&self) -> (*const T, usize); +} + +impl SliceOptionExt for Option<&[T]> { + fn get_ptr_and_len(&self) -> (*const T, usize) { + match self { + Some(self_slice) => (self_slice.as_ptr(), self_slice.len()), + None => (std::ptr::null(), 0), + } + } +} + +pub trait MutSliceOptionExt { + fn get_mut_ptr_and_len(&mut self) -> (*mut T, usize); +} + +impl MutSliceOptionExt for Option<&mut [T]> { + fn get_mut_ptr_and_len(&mut self) -> (*mut T, usize) { + match self { + Some(self_slice) => (self_slice.as_mut_ptr(), self_slice.len()), + None => (std::ptr::null_mut(), 0), + } + } +} diff --git a/src/libos/src/syscall/mod.rs b/src/libos/src/syscall/mod.rs index 6d08c41a..90a3b44d 100644 --- a/src/libos/src/syscall/mod.rs +++ b/src/libos/src/syscall/mod.rs @@ -7,8 +7,12 @@ //! 3. Dispatch the syscall to `do_*` (at this file) //! 4. Do some memory checks then call `mod::do_*` (at each module) -use fs::*; +use fs::{ + AccessFlags, AccessModes, AsUnixSocket, FcntlCmd, File, FileDesc, FileRef, IoctlCmd, + UnixSocketFile, AT_FDCWD, +}; use misc::{resource_t, rlimit_t, utsname_t}; +use net::{msghdr, msghdr_mut, AsSocket, SocketFile}; use process::{pid_t, ChildProcessFilter, CloneFlags, CpuSet, FileAction, FutexFlags, FutexOp}; use std::ffi::{CStr, CString}; use std::ptr; @@ -307,6 +311,7 @@ pub extern "C" fn dispatch_syscall( arg4 as *mut libc::sockaddr, arg5 as *mut libc::socklen_t, ), + SYS_SOCKETPAIR => do_socketpair( arg0 as c_int, arg1 as c_int, @@ -314,6 +319,9 @@ pub extern "C" fn dispatch_syscall( arg3 as *mut c_int, ), + SYS_SENDMSG => net::do_sendmsg(arg0 as c_int, arg1 as *const msghdr, arg2 as c_int), + SYS_RECVMSG => net::do_recvmsg(arg0 as c_int, arg1 as *mut msghdr_mut, arg2 as c_int), + _ => do_unknown(num, arg0, arg1, arg2, arg3, arg4, arg5), }; diff --git a/src/pal/pal.c b/src/pal/pal.c index 5634bcf5..c2a76cf8 100644 --- a/src/pal/pal.c +++ b/src/pal/pal.c @@ -15,6 +15,7 @@ #include #include #include +#include #include #include @@ -296,6 +297,65 @@ void ocall_sync(void) { sync(); } +ssize_t ocall_sendmsg(int sockfd, + const void *msg_name, + socklen_t msg_namelen, + const void *buf, + size_t buf_len, + const void *msg_control, + size_t msg_controllen, + int flags) +{ + struct iovec msg_iov = { .iov_base = (void*)buf, .iov_len = buf_len }; + struct iovec* p_msg_iov = buf != NULL ? &msg_iov : NULL; + size_t msg_iovlen = buf != NULL ? 1 : 0; + + struct msghdr msg = { + (void*) msg_name, + msg_namelen, + p_msg_iov, + msg_iovlen, + (void*) msg_control, + msg_controllen, + 0, + }; + return sendmsg(sockfd, &msg, flags); +} + +ssize_t ocall_recvmsg(int sockfd, + void *msg_name, + socklen_t msg_namelen, + socklen_t* msg_namelen_recv, + void *buf, + size_t buf_len, + void *msg_control, + size_t msg_controllen, + size_t* msg_controllen_recv, + int* msg_flags_recv, + int flags) +{ + struct iovec msg_iov = { .iov_base = buf, .iov_len = buf_len }; + struct iovec* p_msg_iov = buf != NULL ? &msg_iov : NULL; + size_t msg_iovlen = buf != NULL ? 1 : 0; + + struct msghdr msg = { + msg_name, + msg_namelen, + p_msg_iov, + msg_iovlen, + msg_control, + msg_controllen, + 0, + }; + ssize_t ret = recvmsg(sockfd, &msg, flags); + if (ret < 0) return ret; + + *msg_namelen_recv = msg.msg_namelen; + *msg_controllen_recv = msg.msg_controllen; + *msg_flags_recv = msg.msg_flags; + return ret; +} + // ========================================================================== // Main // ========================================================================== diff --git a/test/client/main.c b/test/client/main.c index ab0826f7..9d5b75d1 100644 --- a/test/client/main.c +++ b/test/client/main.c @@ -9,45 +9,138 @@ #include #include -int main(int argc, const char *argv[]) { - const char* message = "Hello world!"; - int ret; +#include "test.h" - if (argc != 3) { - printf("usage: ./client \n"); - return 0; - } +#define RESPONSE "ACK" +#define DEFAULT_MSG "Hello World!\n" - int sockfd = socket(AF_INET, SOCK_STREAM, 0); - if (sockfd < 0) { - printf("create socket error: %s(errno: %d)\n", strerror(errno), errno); - return -1; - } +int connect_with_server(const char *addr_string, const char *port_string) { + //"NULL" addr means connectionless, no need to connect to server + if (strcmp(addr_string, "NULL") == 0) + return 0; - struct sockaddr_in servaddr; - memset(&servaddr, 0, sizeof(servaddr)); - servaddr.sin_family = AF_INET; - servaddr.sin_port = htons((uint16_t)strtol(argv[2], NULL, 10)); + int ret = 0; + int sockfd = socket(AF_INET, SOCK_STREAM, 0); + if (sockfd < 0) + THROW_ERROR("create socket error"); - ret = inet_pton(AF_INET, argv[1], &servaddr.sin_addr); - if (ret <= 0) { - printf("inet_pton error for %s\n", argv[1]); - return -1; - } + struct sockaddr_in servaddr; + memset(&servaddr, 0, sizeof(servaddr)); + servaddr.sin_family = AF_INET; + servaddr.sin_port = htons((uint16_t)strtol(port_string, NULL, 10)); + ret = inet_pton(AF_INET, addr_string, &servaddr.sin_addr); + if (ret <= 0) { + close(sockfd); + THROW_ERROR("inet_pton error"); + } - ret = connect(sockfd, (struct sockaddr *) &servaddr, sizeof(servaddr)); - if (ret < 0) { - printf("connect error: %s(errno: %d)\n", strerror(errno), errno); - return -1; - } + ret = connect(sockfd, (struct sockaddr *) &servaddr, sizeof(servaddr)); + if (ret < 0) { + close(sockfd); + THROW_ERROR("connect error"); + } - printf("send msg to server: %s\n", message); - ret = send(sockfd, message, strlen(message), 0); - if (ret < 0) { - printf("send msg error: %s(errno: %d)\n", strerror(errno), errno); - return -1; - } - - close(sockfd); - return 0; + return sockfd; +} + +int neogotiate_msg(int server_fd, char *buf, int buf_size) { + if (read(server_fd, buf, buf_size) < 0) + THROW_ERROR("read failed"); + + if (write(server_fd, RESPONSE, sizeof(RESPONSE)) < 0) { + THROW_ERROR("write failed"); + } + return 0; +} + +int client_send(int server_fd, char *buf) { + if (send(server_fd, buf, strlen(buf), 0) < 0) + THROW_ERROR("send msg error"); + return 0; +} + +int client_sendmsg(int server_fd, char *buf) { + int ret = 0; + struct msghdr msg; + struct iovec iov[1]; + msg.msg_name = NULL; + msg.msg_namelen = 0; + iov[0].iov_base = buf; + iov[0].iov_len = strlen(buf); + msg.msg_iov = iov; + msg.msg_iovlen = 1; + msg.msg_control = 0; + msg.msg_controllen = 0; + msg.msg_flags = 0; + + ret = sendmsg(server_fd, &msg, 0); + if (ret <= 0) + THROW_ERROR("sendmsg failed"); + return ret; +} + +int client_connectionless_sendmsg(char *buf) { + int ret = 0; + struct msghdr msg; + struct iovec iov[1]; + struct sockaddr_in servaddr; + memset(&servaddr, 0, sizeof(servaddr)); + + servaddr.sin_family = AF_INET; + servaddr.sin_port = htons(9900); + servaddr.sin_addr.s_addr= htonl(INADDR_ANY); + + msg.msg_name = &servaddr; + msg.msg_namelen = sizeof(servaddr); + iov[0].iov_base = buf; + iov[0].iov_len = strlen(buf); + msg.msg_iov = iov; + msg.msg_iovlen = 1; + msg.msg_control = 0; + msg.msg_controllen = 0; + msg.msg_flags = 0; + + int server_fd = socket(AF_INET, SOCK_DGRAM, 0); + if (server_fd < 0) + THROW_ERROR("create socket error"); + + ret = sendmsg(server_fd, &msg, 0); + if (ret <= 0) + THROW_ERROR("sendmsg failed"); + return ret; +} + +int main(int argc, const char *argv[]) { + if (argc != 3) { + THROW_ERROR("usage: ./client \n"); + } + + int ret = 0; + const int buf_size = 100; + char buf[buf_size]; + int port = strtol(argv[2], NULL, 10); + int server_fd = connect_with_server(argv[1], argv[2]); + + switch (port) + { + case 8800: + neogotiate_msg(server_fd, buf, buf_size); + break; + case 8801: + neogotiate_msg(server_fd, buf, buf_size); + ret = client_send(server_fd, buf); + break; + case 8802: + neogotiate_msg(server_fd, buf, buf_size); + ret = client_sendmsg(server_fd, buf); + break; + case 8803: + ret = client_connectionless_sendmsg(DEFAULT_MSG); + break; + default: + ret = client_send(server_fd, DEFAULT_MSG); + } + + close(server_fd); + return ret; } diff --git a/test/server/main.c b/test/server/main.c index 6bfa6ce1..89f6b950 100644 --- a/test/server/main.c +++ b/test/server/main.c @@ -4,62 +4,259 @@ #include #include #include +#include #include +#include #include #include -int main(int argc, const char *argv[]) { - const int BUF_SIZE = 0x1000; - int ret; +#include "test.h" - int listenfd = socket(AF_INET, SOCK_STREAM, 0); - if (listenfd < 0) { - printf("create socket error: %s(errno: %d)\n", strerror(errno), errno); - return -1; - } +#define ECHO_MSG "msg for client/server test" +#define RESPONSE "ACK" +#define DEFAULT_MSG "Hello World!\n" - struct sockaddr_in servaddr; - memset(&servaddr, 0, sizeof(servaddr)); - servaddr.sin_family = AF_INET; - servaddr.sin_addr.s_addr = htonl(INADDR_ANY); - servaddr.sin_port = htons(6666); +int connect_with_child(int port, int *child_pid) { + int ret = 0; + int listen_fd = socket(AF_INET, SOCK_STREAM, 0); + if (listen_fd < 0) + THROW_ERROR("create socket error"); + int reuse = 1; + if (setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)) < 0) + THROW_ERROR("setsockopt port to reuse failed"); - int reuse = 1; - if (setsockopt(listenfd, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)) < 0) - perror("setsockopt port to reuse failed"); + struct sockaddr_in servaddr; + memset(&servaddr, 0, sizeof(servaddr)); + servaddr.sin_family = AF_INET; + servaddr.sin_addr.s_addr = htonl(INADDR_ANY); + servaddr.sin_port = htons(port); + ret = bind(listen_fd, (struct sockaddr *) &servaddr, sizeof(servaddr)); + if (ret < 0) { + close(listen_fd); + THROW_ERROR("bind socket failed"); + } - ret = bind(listenfd, (struct sockaddr *) &servaddr, sizeof(servaddr)); - if (ret < 0) { - printf("bind socket error: %s(errno: %d)\n", strerror(errno), errno); - return -1; - } + ret = listen(listen_fd, 10); + if (ret < 0) { + close(listen_fd); + THROW_ERROR("listen socket error"); + } - ret = listen(listenfd, 10); - if (ret < 0) { - printf("listen socket error: %s(errno: %d)\n", strerror(errno), errno); - return -1; - } + char port_string[8]; + sprintf(port_string, "%d", port); + char* client_argv[] = {"client", "127.0.0.1", port_string, NULL}; + ret = posix_spawn(child_pid, "/bin/client", NULL, NULL, client_argv, NULL); + if (ret < 0) { + close(listen_fd); + THROW_ERROR("spawn client process error"); + } - int client_pid; - char* client_argv[] = {"client", "127.0.0.1", "6666", NULL}; - ret = posix_spawn(&client_pid, "/bin/client", NULL, NULL, client_argv, NULL); - if (ret < 0) { - printf("spawn client process error: %s(errno: %d)\n", strerror(errno), errno); - return -1; - } + int connected_fd = accept(listen_fd, (struct sockaddr *) NULL, NULL); + if (connected_fd < 0) { + close(listen_fd); + THROW_ERROR("accept socket error"); + } - printf("====== waiting for client's request ======\n"); - int connect_fd = accept(listenfd, (struct sockaddr *) NULL, NULL); - if (connect_fd < 0) { - printf("accept socket error: %s(errno: %d)", strerror(errno), errno); - return -1; - } - char buff[BUF_SIZE]; - int n = recv(connect_fd, buff, BUF_SIZE, 0); - buff[n] = '\0'; - printf("recv msg from client: %s\n", buff); - - close(connect_fd); - close(listenfd); - return 0; + close(listen_fd); + return connected_fd; +} + +int neogotiate_msg(int client_fd) { + char buf[16]; + if (write(client_fd, ECHO_MSG, sizeof(ECHO_MSG)) < 0) + THROW_ERROR("write failed"); + + if (read(client_fd, buf, 16) < 0) + THROW_ERROR("read failed"); + + if (strncmp(buf, RESPONSE, sizeof(RESPONSE)) != 0) { + THROW_ERROR("msg recv mismatch"); + } + return 0; +} + +int server_recv(int client_fd) { + const int buf_size = 32; + char buf[buf_size]; + + if (recv(client_fd, buf, buf_size, 0) <= 0) + THROW_ERROR("msg recv failed"); + + if (strncmp(buf, ECHO_MSG, sizeof(ECHO_MSG)) != 0) { + THROW_ERROR("msg recv mismatch"); + } + return 0; +} + +int server_recvmsg(int client_fd) { + int ret = 0; + const int buf_size = 1000; + char buf[buf_size]; + struct msghdr msg; + struct iovec iov[1]; + + msg.msg_name = NULL; + msg.msg_namelen = 0; + iov[0].iov_base = buf; + iov[0].iov_len = buf_size; + msg.msg_iov = iov; + msg.msg_iovlen = 1; + msg.msg_control = 0; + msg.msg_controllen = 0; + msg.msg_flags = 0; + + ret = recvmsg(client_fd, &msg, 0); + if (ret <= 0) { + THROW_ERROR("recvmsg failed"); + } else { + if (strncmp(buf, ECHO_MSG, sizeof(ECHO_MSG)) != 0) { + printf("recvmsg : %d, msg: %s\n", ret, buf); + THROW_ERROR("msg recvmsg mismatch"); + } + } + return ret; +} + +int server_connectionless_recvmsg() { + int ret = 0; + const int buf_size = 1000; + char buf[buf_size]; + struct msghdr msg; + struct iovec iov[1]; + + struct sockaddr_in servaddr; + struct sockaddr_in clientaddr; + memset(&servaddr, 0, sizeof(servaddr)); + memset(&clientaddr, 0, sizeof(clientaddr)); + + int sock = socket(AF_INET, SOCK_DGRAM, 0); + if (sock < 0) + THROW_ERROR("create socket error"); + + servaddr.sin_family = AF_INET; + servaddr.sin_addr.s_addr = htonl(INADDR_ANY); + servaddr.sin_port = htons(9900); + ret = bind(sock, (struct sockaddr *) &servaddr, sizeof(servaddr)); + if (ret < 0) { + close(sock); + THROW_ERROR("bind socket failed"); + } + + msg.msg_name = &clientaddr; + msg.msg_namelen = sizeof(clientaddr); + iov[0].iov_base = buf; + iov[0].iov_len = buf_size; + msg.msg_iov = iov; + msg.msg_iovlen = 1; + msg.msg_control = 0; + msg.msg_controllen = 0; + msg.msg_flags = 0; + + ret = recvmsg(sock, &msg, 0); + if (ret <= 0) { + THROW_ERROR("recvmsg failed"); + } else { + if (strncmp(buf, DEFAULT_MSG, sizeof(DEFAULT_MSG)) != 0) { + printf("recvmsg : %d, msg: %s\n", ret, buf); + THROW_ERROR("msg recvmsg mismatch"); + } else { + inet_ntop(AF_INET, &clientaddr.sin_addr, + buf, sizeof(buf)); + if(strcmp(buf, "127.0.0.1") !=0) { + printf("from port %d and address %s\n", ntohs(clientaddr.sin_port), buf); + THROW_ERROR("client addr mismatch"); + } + } + } + return ret; +} + +int wait_for_child_exit(int child_pid) { + int status = 0; + if (wait4(child_pid, &status, 0, NULL) < 0) { + THROW_ERROR("failed to wait4 the child process"); + } + return 0; +} + +int test_read_write() { + int ret = 0; + int child_pid = 0; + int client_fd = connect_with_child(8800, &child_pid); + if (client_fd < 0) + THROW_ERROR("connect failed"); + else + ret = neogotiate_msg(client_fd); + + //wait for the child to exit for next spawn + int status = 0; + if (wait4(child_pid, &status, 0, NULL) < 0) { + THROW_ERROR("failed to wait4 the child process"); + } + + return ret; +} + +int test_send_recv() { + int ret = 0; + int child_pid = 0; + int client_fd = connect_with_child(8801, &child_pid); + if (client_fd < 0) + THROW_ERROR("connect failed"); + + if (neogotiate_msg(client_fd) < 0) + THROW_ERROR("neogotiate failed"); + + ret = server_recv(client_fd); + if (ret < 0) return -1; + + ret = wait_for_child_exit(child_pid); + + return ret; +} + +int test_sendmsg_recvmsg() { + int ret = 0; + int child_pid = 0; + int client_fd = connect_with_child(8802, &child_pid); + if (client_fd < 0) + THROW_ERROR("connect failed"); + + if (neogotiate_msg(client_fd) < 0) + THROW_ERROR("neogotiate failed"); + + ret = server_recvmsg(client_fd); + if (ret < 0) return -1; + + ret = wait_for_child_exit(child_pid); + + return ret; +} + +int test_sendmsg_recvmsg_connectionless() { + int ret = 0; + int child_pid = 0; + + char* client_argv[] = {"client", "NULL", "8803", NULL}; + ret = posix_spawn(&child_pid, "/bin/client", NULL, NULL, client_argv, NULL); + if (ret < 0) { + THROW_ERROR("spawn client process error"); + } + + ret = server_connectionless_recvmsg(); + if (ret < 0) return -1; + + ret = wait_for_child_exit(child_pid); + + return ret; +} +static test_case_t test_cases[] = { + TEST_CASE(test_read_write), + TEST_CASE(test_send_recv), + TEST_CASE(test_sendmsg_recvmsg), + TEST_CASE(test_sendmsg_recvmsg_connectionless), +}; + +int main(int argc, const char* argv[]) { + return test_suite_run(test_cases, ARRAY_SIZE(test_cases)); }