diff --git a/src/libos/src/fs/events.rs b/src/libos/src/fs/events.rs index 3f2afb6f..14b19e33 100644 --- a/src/libos/src/fs/events.rs +++ b/src/libos/src/fs/events.rs @@ -14,6 +14,9 @@ bitflags! { const HUP = 0x0010; // = POLLHUP const NVAL = 0x0020; // = POLLNVAL const RDHUP = 0x2000; // = POLLRDHUP + + /// Events that are always polled even without specifying them. + const ALWAYS_POLL = Self::ERR.bits | Self::HUP.bits; } } diff --git a/src/libos/src/fs/file_ops/file_flags.rs b/src/libos/src/fs/file_ops/file_flags.rs index b06f1545..2ceb8c94 100644 --- a/src/libos/src/fs/file_ops/file_flags.rs +++ b/src/libos/src/fs/file_ops/file_flags.rs @@ -129,4 +129,8 @@ impl StatusFlags { pub fn is_fast_open(&self) -> bool { self.contains(StatusFlags::O_PATH) } + + pub fn is_nonblocking(&self) -> bool { + self.contains(StatusFlags::O_NONBLOCK) + } } diff --git a/src/libos/src/fs/file_ops/ioctl/builtin/get_ifconf.rs b/src/libos/src/fs/file_ops/ioctl/builtin/get_ifconf.rs index 66d89a1d..c507a248 100644 --- a/src/libos/src/fs/file_ops/ioctl/builtin/get_ifconf.rs +++ b/src/libos/src/fs/file_ops/ioctl/builtin/get_ifconf.rs @@ -99,7 +99,7 @@ fn get_ifconf_by_host(fd: FileDesc, if_conf: &mut IfConf) -> Result<()> { // len: the size of the buf // recv_len: accepts transferred data length when buf is used to get data from host // - fn socket_ocall_ioctl_repack( + fn occlum_ocall_ioctl_repack( ret: *mut i32, fd: i32, cmd_num: i32, @@ -112,7 +112,7 @@ fn get_ifconf_by_host(fd: FileDesc, if_conf: &mut IfConf) -> Result<()> { try_libc!({ let mut recv_len: i32 = 0; let mut retval: i32 = 0; - let status = socket_ocall_ioctl_repack( + let status = occlum_ocall_ioctl_repack( &mut retval as *mut i32, fd as _, SIOCGIFCONF as _, diff --git a/src/libos/src/net/io_multiplexing/poll_new/event_monitor.rs b/src/libos/src/net/io_multiplexing/poll_new/event_monitor.rs index be06dbdc..4d10de1b 100644 --- a/src/libos/src/net/io_multiplexing/poll_new/event_monitor.rs +++ b/src/libos/src/net/io_multiplexing/poll_new/event_monitor.rs @@ -1,10 +1,9 @@ -use std::cell::Cell; use std::ptr; use std::sync::Weak; use std::time::Duration; use crate::events::{Observer, Waiter, WaiterQueueObserver}; -use crate::fs::{AtomicIoEvents, IoEvents}; +use crate::fs::IoEvents; use crate::prelude::*; use crate::time::{timespec_t, TIMERSLACK}; @@ -265,13 +264,13 @@ trait ObserverExt { impl ObserverExt for Weak> { fn register_files<'a>(&self, files_and_events: impl Iterator) { for (file, events) in files_and_events { - let notifier = match file.notifier() { + match file.notifier() { None => continue, - Some(notifier) => notifier, + Some(notifier) => { + let mask = *events; + notifier.register(self.clone(), Some(mask), None); + } }; - - let mask = *events; - notifier.register(self.clone(), Some(mask), None); } } @@ -280,11 +279,12 @@ impl ObserverExt for Weak> { files_and_events: impl Iterator, ) { for (file, events) in files_and_events { - let notifier = match file.notifier() { + match file.notifier() { None => continue, - Some(notifier) => notifier, + Some(notifier) => { + notifier.unregister(self); + } }; - notifier.unregister(self); } } } diff --git a/src/libos/src/net/socket/flags.rs b/src/libos/src/net/socket/flags.rs deleted file mode 100644 index f2ad14eb..00000000 --- a/src/libos/src/net/socket/flags.rs +++ /dev/null @@ -1,43 +0,0 @@ -use super::*; - -bitflags! { - pub struct SendFlags: i32 { - const MSG_OOB = 0x01; - const MSG_DONTROUTE = 0x04; - const MSG_DONTWAIT = 0x40; // Nonblocking io - const MSG_EOR = 0x80; // End of record - const MSG_CONFIRM = 0x0800; // Confirm path validity - const MSG_NOSIGNAL = 0x4000; // Do not generate SIGPIPE - const MSG_MORE = 0x8000; // Sender will send more - } -} - -bitflags! { - pub struct RecvFlags: i32 { - const MSG_OOB = 0x01; - const MSG_PEEK = 0x02; - const MSG_TRUNC = 0x20; - const MSG_DONTWAIT = 0x40; // Nonblocking io - const MSG_WAITALL = 0x0100; // Wait for a full request - const MSG_ERRQUEUE = 0x2000; // Fetch message from error queue - const MSG_CMSG_CLOEXEC = 0x40000000; // Set close_on_exec for file descriptor received through M_RIGHTS - } -} - -bitflags! { - pub struct MsgHdrFlags: i32 { - const MSG_OOB = 0x01; - const MSG_CTRUNC = 0x08; - const MSG_TRUNC = 0x20; - const MSG_EOR = 0x80; // End of record - const MSG_ERRQUEUE = 0x2000; // Fetch message from error queue - const MSG_NOTIFICATION = 0x8000; // Only applicable to SCTP socket - } -} - -bitflags! { - pub struct FileFlags: i32 { - const SOCK_NONBLOCK = 0x800; - const SOCK_CLOEXEC = 0x80000; - } -} diff --git a/src/libos/src/net/socket/host/ioctl_impl.rs b/src/libos/src/net/socket/host/ioctl_impl.rs deleted file mode 100644 index b551353d..00000000 --- a/src/libos/src/net/socket/host/ioctl_impl.rs +++ /dev/null @@ -1,82 +0,0 @@ -use super::*; -use fs::{occlum_ocall_ioctl, BuiltinIoctlNum, IfConf, IoctlCmd}; - -impl HostSocket { - pub(super) fn ioctl_impl(&self, cmd: &mut IoctlCmd) -> Result { - if let IoctlCmd::SIOCGIFCONF(arg_ref) = cmd { - return self.ioctl_getifconf(arg_ref); - } - - let cmd_num = cmd.cmd_num() as c_int; - let cmd_arg_ptr = cmd.arg_ptr() as *mut c_void; - let ret = try_libc!({ - let mut retval: i32 = 0; - let status = occlum_ocall_ioctl( - &mut retval as *mut i32, - self.raw_host_fd() as i32, - cmd_num, - cmd_arg_ptr, - cmd.arg_len(), - ); - assert!(status == sgx_status_t::SGX_SUCCESS); - retval - }); - // FIXME: add sanity checks for results returned for socket-related ioctls - cmd.validate_arg_and_ret_vals(ret)?; - Ok(ret) - } - - fn ioctl_getifconf(&self, arg_ref: &mut IfConf) -> Result { - if !arg_ref.ifc_buf.is_null() && arg_ref.ifc_len == 0 { - return Ok(0); - } - - let ret = try_libc!({ - let mut recv_len: i32 = 0; - let mut retval: i32 = 0; - let status = occlum_ocall_ioctl_repack( - &mut retval as *mut i32, - self.raw_host_fd() as i32, - BuiltinIoctlNum::SIOCGIFCONF as i32, - arg_ref.ifc_buf, - arg_ref.ifc_len, - &mut recv_len as *mut i32, - ); - assert!(status == sgx_status_t::SGX_SUCCESS); - // If ifc_req is NULL, SIOCGIFCONF returns the necessary buffer - // size in bytes for receiving all available addresses in ifc_len - // which is irrelevant to the orginal ifc_len. - if !arg_ref.ifc_buf.is_null() { - assert!(arg_ref.ifc_len >= recv_len); - } - - arg_ref.ifc_len = recv_len; - retval - }); - Ok(ret) - } -} - -extern "C" { - // Used to ioctl arguments with pointer members. - // - // Before the call the area the pointers points to should be assembled into - // one continous memory block. Then the block is repacked to ioctl arguments - // in the ocall implementation in host. - // - // ret: holds the return value of ioctl in host - // fd: the host fd for the device - // cmd_num: request number of the ioctl - // buf: the data to exchange with host - // len: the size of the buf - // recv_len: accepts transferred data length when buf is used to get data from host - // - fn occlum_ocall_ioctl_repack( - ret: *mut i32, - fd: c_int, - cmd_num: c_int, - buf: *const u8, - len: i32, - recv_len: *mut i32, - ) -> sgx_status_t; -} diff --git a/src/libos/src/net/socket/host/mod.rs b/src/libos/src/net/socket/host/mod.rs index 974ced50..24a3e277 100644 --- a/src/libos/src/net/socket/host/mod.rs +++ b/src/libos/src/net/socket/host/mod.rs @@ -1,18 +1,16 @@ use std::any::Any; use std::io::{Read, Seek, SeekFrom, Write}; use std::mem; - use atomic::Atomic; use super::*; use crate::fs::{ occlum_ocall_ioctl, AccessMode, CreationFlags, File, FileRef, HostFd, IoEvents, IoNotifier, - IoctlCmd, StatusFlags, + IoctlRawCmd, StatusFlags, }; use crate::process::IO_BUF_SIZE; -mod ioctl_impl; mod recv; mod send; mod socket_file; @@ -27,14 +25,14 @@ pub struct HostSocket { impl HostSocket { pub fn new( - domain: AddressFamily, - socket_type: SocketType, - file_flags: FileFlags, + domain: Domain, + socket_type: Type, + socket_flags: SocketFlags, protocol: i32, ) -> Result { let raw_host_fd = try_libc!(libc::ocall::socket( domain as i32, - socket_type as i32 | file_flags.bits(), + socket_type as i32 | socket_flags.bits(), protocol )) as FileDesc; let host_fd = HostFd::new(raw_host_fd); @@ -51,7 +49,7 @@ impl HostSocket { }) } - pub fn bind(&self, addr: &SockAddr) -> Result<()> { + pub fn bind(&self, addr: &RawAddr) -> Result<()> { let (addr_ptr, addr_len) = addr.as_ptr_and_len(); let ret = try_libc!(libc::ocall::bind( @@ -67,8 +65,8 @@ impl HostSocket { Ok(()) } - pub fn accept(&self, flags: FileFlags) -> Result<(Self, Option)> { - let mut sockaddr = SockAddr::default(); + pub fn accept(&self, flags: SocketFlags) -> Result<(Self, Option)> { + let mut sockaddr = RawAddr::default(); let mut addr_len = sockaddr.len(); let raw_host_fd = try_libc!(libc::ocall::accept4( @@ -88,7 +86,33 @@ impl HostSocket { Ok((HostSocket::from_host_fd(host_fd)?, addr_option)) } - pub fn connect(&self, addr: &Option) -> Result<()> { + pub fn addr(&self) -> Result { + let mut sockaddr = RawAddr::default(); + let mut addr_len = sockaddr.len(); + try_libc!(libc::ocall::getsockname( + self.raw_host_fd() as i32, + sockaddr.as_mut_ptr() as *mut _, + &mut addr_len as *mut _ as *mut _, + )); + + sockaddr.set_len(addr_len)?; + Ok(sockaddr) + } + + pub fn peer_addr(&self) -> Result { + let mut sockaddr = RawAddr::default(); + let mut addr_len = sockaddr.len(); + try_libc!(libc::ocall::getpeername( + self.raw_host_fd() as i32, + sockaddr.as_mut_ptr() as *mut _, + &mut addr_len as *mut _ as *mut _, + )); + + sockaddr.set_len(addr_len)?; + Ok(sockaddr) + } + + pub fn connect(&self, addr: &Option) -> Result<()> { debug!("connect: host_fd: {}, addr {:?}", self.raw_host_fd(), addr); let (addr_ptr, addr_len) = if let Some(sock_addr) = addr { @@ -109,34 +133,29 @@ impl HostSocket { &self, buf: &[u8], flags: SendFlags, - addr_option: &Option, + addr_option: &Option, ) -> Result { let bufs = vec![buf]; - let name_option = addr_option.as_ref().map(|addr| addr.as_slice()); - self.do_sendmsg(&bufs, flags, name_option, None) + self.sendmsg(&bufs, flags, addr_option, None) } - pub fn recvfrom(&self, buf: &mut [u8], flags: RecvFlags) -> Result<(usize, Option)> { - let mut sockaddr = SockAddr::default(); + pub fn recvfrom(&self, buf: &mut [u8], flags: RecvFlags) -> Result<(usize, Option)> { + let mut sockaddr = RawAddr::default(); let mut bufs = vec![buf]; - let (bytes_recv, addr_len, _, _) = - self.do_recvmsg(&mut bufs, flags, Some(sockaddr.as_mut_slice()), None)?; + let (bytes_recv, recv_addr, _, _) = self.recvmsg(&mut bufs, flags, None)?; - let addr_option = if addr_len != 0 { - sockaddr.set_len(addr_len)?; - Some(sockaddr) - } else { - None - }; - Ok((bytes_recv, addr_option)) + Ok((bytes_recv, recv_addr)) } pub fn raw_host_fd(&self) -> FileDesc { self.host_fd.to_raw() } - pub fn shutdown(&self, how: HowToShut) -> Result<()> { - try_libc!(libc::ocall::shutdown(self.raw_host_fd() as i32, how.bits())); + pub fn shutdown(&self, how: Shutdown) -> Result<()> { + try_libc!(libc::ocall::shutdown( + self.raw_host_fd() as i32, + how.to_c() as i32 + )); Ok(()) } } diff --git a/src/libos/src/net/socket/host/recv.rs b/src/libos/src/net/socket/host/recv.rs index e8084694..edacdb37 100644 --- a/src/libos/src/net/socket/host/recv.rs +++ b/src/libos/src/net/socket/host/recv.rs @@ -7,30 +7,12 @@ impl HostSocket { Ok(bytes_recvd) } - pub fn recvmsg<'a, 'b>(&self, msg: &'b mut MsgHdrMut<'a>, flags: RecvFlags) -> Result { - // Do OCall-based recvmsg - let (bytes_recvd, namelen_recvd, controllen_recvd, flags_recvd) = { - // Acquire mutable references to the name and control buffers - let (iovs, name, control) = msg.get_iovs_name_and_control_mut(); - // Fill the data, the name, and the control buffers - self.do_recvmsg(iovs.as_slices_mut(), flags, name, control)? - }; - - // Update the output lengths and flags - msg.set_name_len(namelen_recvd)?; - msg.set_control_len(controllen_recvd)?; - msg.set_flags(flags_recvd); - - Ok(bytes_recvd) - } - - pub(super) fn do_recvmsg( + pub fn recvmsg( &self, data: &mut [&mut [u8]], flags: RecvFlags, - mut name: Option<&mut [u8]>, - mut control: Option<&mut [u8]>, - ) -> Result<(usize, usize, usize, MsgHdrFlags)> { + control: Option<&mut [u8]>, + ) -> Result<(usize, Option, MsgFlags, usize)> { let current = current!(); let data_length = data.iter().map(|s| s.len()).sum(); let mut ocall_alloc; @@ -52,7 +34,7 @@ impl HostSocket { } bufs }; - let retval = self.do_recvmsg_untrusted_data(&mut u_data, flags, name, control)?; + let retval = self.do_recvmsg_untrusted_data(&mut u_data, flags, control)?; let mut remain = retval.0; for (i, buf) in data.iter_mut().enumerate() { @@ -71,16 +53,15 @@ impl HostSocket { &self, data: &mut [UntrustedSlice], flags: RecvFlags, - mut name: Option<&mut [u8]>, mut control: Option<&mut [u8]>, - ) -> Result<(usize, usize, usize, MsgHdrFlags)> { + ) -> Result<(usize, Option, MsgFlags, usize)> { // Prepare the arguments for OCall - // Host socket fd let host_fd = self.raw_host_fd() as i32; - // Name - let (msg_name, msg_namelen) = name.as_mut_ptr_and_len(); - let msg_name = msg_name as *mut c_void; + let mut addr = RawAddr::default(); + let mut msg_name = addr.as_mut_ptr(); + let mut msg_namelen = addr.len(); let mut msg_namelen_recvd = 0_u32; + // Iovs let mut raw_iovs: Vec = data .iter() @@ -101,7 +82,7 @@ impl HostSocket { let status = occlum_ocall_recvmsg( &mut retval as *mut isize, host_fd, - msg_name, + msg_name as _, msg_namelen as u32, &mut msg_namelen_recvd as *mut u32, msg_iov, @@ -119,7 +100,7 @@ impl HostSocket { retval }); - let flags_recvd = MsgHdrFlags::from_bits(msg_flags_recvd).unwrap(); + let flags_recvd = MsgFlags::from_bits(msg_flags_recvd).unwrap(); // Check values returned from outside the enclave let bytes_recvd = { @@ -133,21 +114,24 @@ impl HostSocket { // For MSG_TRUNC recvmsg returns the real length of the packet or datagram, // even when it was longer than the passed buffer. if flags.contains(RecvFlags::MSG_TRUNC) && retval > max_bytes_recvd { - assert!(flags_recvd.contains(MsgHdrFlags::MSG_TRUNC)); + assert!(flags_recvd.contains(MsgFlags::MSG_TRUNC)); } else { assert!(retval <= max_bytes_recvd); } retval }; let msg_namelen_recvd = msg_namelen_recvd as usize; + + let raw_addr = if msg_namelen_recvd == 0 { + None + } else { + addr.set_len(msg_namelen_recvd)?; + Some(addr) + }; + assert!(msg_namelen_recvd <= msg_namelen); assert!(msg_controllen_recvd <= msg_controllen); - Ok(( - bytes_recvd, - msg_namelen_recvd, - msg_controllen_recvd, - flags_recvd, - )) + Ok((bytes_recvd, raw_addr, flags_recvd, msg_controllen_recvd)) } } diff --git a/src/libos/src/net/socket/host/send.rs b/src/libos/src/net/socket/host/send.rs index 483a89a2..0df6f73c 100644 --- a/src/libos/src/net/socket/host/send.rs +++ b/src/libos/src/net/socket/host/send.rs @@ -5,22 +5,11 @@ impl HostSocket { self.sendto(buf, flags, &None) } - pub fn sendmsg<'a, 'b>(&self, msg: &'b MsgHdr<'a>, flags: SendFlags) -> Result { - let msg_iov = msg.get_iovs(); - - self.do_sendmsg( - msg_iov.as_slices(), - flags, - msg.get_name(), - msg.get_control(), - ) - } - - pub(super) fn do_sendmsg( + pub fn sendmsg( &self, data: &[&[u8]], flags: SendFlags, - name: Option<&[u8]>, + addr: &Option, control: Option<&[u8]>, ) -> Result { let current = current!(); @@ -45,8 +34,8 @@ impl HostSocket { bufs }; - let retval = self.do_sendmsg_untrusted_data(&u_data, flags, name, control); - retval + let name = addr.as_ref().map(|raw_addr| raw_addr.as_slice()); + self.do_sendmsg_untrusted_data(&u_data, flags, name, control) } fn do_sendmsg_untrusted_data( diff --git a/src/libos/src/net/socket/host/socket_file.rs b/src/libos/src/net/socket/host/socket_file.rs index 5bd8535a..f259b5cd 100644 --- a/src/libos/src/net/socket/host/socket_file.rs +++ b/src/libos/src/net/socket/host/socket_file.rs @@ -4,10 +4,11 @@ use std::io::{Read, Seek, SeekFrom, Write}; use atomic::{Atomic, Ordering}; use super::*; +use crate::fs::{AccessMode, File, HostFd, IoEvents, StatusFlags, STATUS_FLAGS_MASK}; use crate::fs::{ - occlum_ocall_ioctl, AccessMode, AtomicIoEvents, CreationFlags, File, FileRef, HostFd, IoEvents, - IoctlCmd, StatusFlags, STATUS_FLAGS_MASK, + GetIfConf, GetIfReqWithRawCmd, GetReadBufLen, IoctlCmd, NonBuiltinIoctlCmd, SetNonBlocking, }; +use crate::net::socket::sockopt::{GetSockOptRawCmd, SetSockOptRawCmd}; //TODO: refactor write syscall to allow zero length with non-zero buffer impl File for HostSocket { @@ -34,20 +35,46 @@ impl File for HostSocket { } fn readv(&self, bufs: &mut [&mut [u8]]) -> Result { - let (bytes_recvd, _, _, _) = self.do_recvmsg(bufs, RecvFlags::empty(), None, None)?; + let (bytes_recvd, _, _, _) = self.recvmsg(bufs, RecvFlags::empty(), None)?; Ok(bytes_recvd) } fn writev(&self, bufs: &[&[u8]]) -> Result { - self.do_sendmsg(bufs, SendFlags::empty(), None, None) + self.sendmsg(bufs, SendFlags::empty(), &None, None) } fn seek(&self, pos: SeekFrom) -> Result { return_errno!(ESPIPE, "Socket does not support seek") } - fn ioctl(&self, cmd: &mut IoctlCmd) -> Result { - self.ioctl_impl(cmd) + fn ioctl(&self, cmd: &mut dyn IoctlCmd) -> Result<()> { + match_ioctl_cmd_mut!(&mut *cmd, { + cmd: GetSockOptRawCmd => { + cmd.execute(self.raw_host_fd())?; + }, + cmd: SetSockOptRawCmd => { + cmd.execute(self.raw_host_fd())?; + }, + cmd: GetIfReqWithRawCmd => { + cmd.execute(self.raw_host_fd())?; + }, + cmd: GetIfConf => { + cmd.execute(self.raw_host_fd())?; + }, + cmd: GetReadBufLen => { + cmd.execute(self.raw_host_fd())?; + }, + cmd: SetNonBlocking => { + cmd.execute(self.raw_host_fd())?; + }, + cmd: NonBuiltinIoctlCmd => { + cmd.execute(self.raw_host_fd())?; + }, + _ => { + return_errno!(EINVAL, "Not supported yet"); + } + }); + Ok(()) } fn access_mode(&self) -> Result { diff --git a/src/libos/src/net/socket/msg.rs b/src/libos/src/net/socket/msg.rs deleted file mode 100644 index 47cb7a6f..00000000 --- a/src/libos/src/net/socket/msg.rs +++ /dev/null @@ -1,354 +0,0 @@ -/// Socket message and its flags. -use super::*; - -/// C struct for a socket message with const pointers -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct msghdr { - pub msg_name: *const c_void, - pub msg_namelen: libc::socklen_t, - pub msg_iov: *const libc::iovec, - pub msg_iovlen: size_t, - pub msg_control: *const c_void, - pub msg_controllen: size_t, - pub msg_flags: c_int, -} - -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct mmsghdr { - pub msg_hdr: msghdr, - pub msg_len: c_uint, -} - -/// C struct for a socket message with mutable pointers -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct msghdr_mut { - pub msg_name: *mut c_void, - pub msg_namelen: libc::socklen_t, - pub msg_iov: *mut libc::iovec, - pub msg_iovlen: size_t, - pub msg_control: *mut c_void, - pub msg_controllen: size_t, - pub msg_flags: c_int, -} - -/// MsgHdr is a memory-safe, immutable wrapper of msghdr -pub struct MsgHdr<'a> { - name: Option<&'a [u8]>, - iovs: Iovs<'a>, - control: Option<&'a [u8]>, - flags: MsgHdrFlags, - c_self: &'a msghdr, -} - -impl<'a> MsgHdr<'a> { - /// Wrap a unsafe msghdr into a safe MsgHdr - pub unsafe fn from_c(c_msg: &'a msghdr) -> Result { - // Convert c_msg's (*mut T, usize)-pair fields to Option<&mut [T]> - let name_opt_slice = - new_optional_slice(c_msg.msg_name as *const u8, c_msg.msg_namelen as usize); - let iovs_opt_slice = new_optional_slice( - c_msg.msg_iov as *const libc::iovec, - c_msg.msg_iovlen as usize, - ); - let control_opt_slice = new_optional_slice( - c_msg.msg_control as *const u8, - c_msg.msg_controllen as usize, - ); - - let flags = MsgHdrFlags::from_bits_truncate(c_msg.msg_flags); - - let iovs = { - let iovs_vec = match iovs_opt_slice { - Some(iovs_slice) => iovs_slice - .iter() - .flat_map(|iov| new_optional_slice(iov.iov_base as *const u8, iov.iov_len)) - .collect(), - None => Vec::new(), - }; - Iovs::new(iovs_vec) - }; - - Ok(Self { - name: name_opt_slice, - iovs: iovs, - control: control_opt_slice, - flags: flags, - c_self: c_msg, - }) - } - - pub fn get_iovs(&self) -> &Iovs { - &self.iovs - } - - pub fn get_name(&self) -> Option<&[u8]> { - self.name - } - - pub fn get_control(&self) -> Option<&[u8]> { - self.control - } - - pub fn get_flags(&self) -> MsgHdrFlags { - self.flags - } -} - -/// MsgHdrMut is a memory-safe, mutable wrapper of msghdr_mut -pub struct MsgHdrMut<'a> { - name: Option<&'a mut [u8]>, - iovs: IovsMut<'a>, - control: Option<&'a mut [u8]>, - flags: MsgHdrFlags, - c_self: &'a mut msghdr_mut, -} - -// TODO: use macros to eliminate redundant code between MsgHdr and MsgHdrMut -impl<'a> MsgHdrMut<'a> { - /// Wrap a unsafe msghdr_mut into a safe MsgHdrMut - pub unsafe fn from_c(c_msg: &'a mut msghdr_mut) -> Result { - // Convert c_msg's (*mut T, usize)-pair fields to Option<&mut [T]> - let name_opt_slice = - new_optional_slice_mut(c_msg.msg_name as *mut u8, c_msg.msg_namelen as usize); - let iovs_opt_slice = - new_optional_slice_mut(c_msg.msg_iov as *mut libc::iovec, c_msg.msg_iovlen as usize); - let control_opt_slice = - new_optional_slice_mut(c_msg.msg_control as *mut u8, c_msg.msg_controllen as usize); - - let flags = MsgHdrFlags::from_bits_truncate(c_msg.msg_flags); - - let iovs = { - let iovs_vec = match iovs_opt_slice { - Some(iovs_slice) => iovs_slice - .iter() - .flat_map(|iov| new_optional_slice_mut(iov.iov_base as *mut u8, iov.iov_len)) - .collect(), - None => Vec::new(), - }; - IovsMut::new(iovs_vec) - }; - - Ok(Self { - name: name_opt_slice, - iovs: iovs, - control: control_opt_slice, - flags: flags, - c_self: c_msg, - }) - } - - ///////////////////////////////////////////////////////////////////////// - // Immutable interfaces (same as MsgHdr) - ///////////////////////////////////////////////////////////////////////// - - pub fn get_iovs(&self) -> &IovsMut { - &self.iovs - } - - pub fn get_name(&self) -> Option<&[u8]> { - self.name.as_ref().map(|name| &name[..]) - } - - pub fn get_control(&self) -> Option<&[u8]> { - self.control.as_ref().map(|control| &control[..]) - } - - pub fn get_flags(&self) -> MsgHdrFlags { - self.flags - } - - ///////////////////////////////////////////////////////////////////////// - // Mutable interfaces (unique to MsgHdrMut) - ///////////////////////////////////////////////////////////////////////// - - pub fn get_iovs_mut<'b>(&'b mut self) -> &'b mut IovsMut<'a> { - &mut self.iovs - } - - pub fn get_name_mut(&mut self) -> Option<&mut [u8]> { - self.name.as_mut().map(|name| &mut name[..]) - } - - pub fn get_name_max_len(&self) -> usize { - self.name.as_ref().map(|name| name.len()).unwrap_or(0) - } - - pub fn set_name_len(&mut self, new_name_len: usize) -> Result<()> { - if new_name_len > self.get_name_max_len() { - return_errno!(EINVAL, "new_name_len is too big"); - } - self.c_self.msg_namelen = new_name_len as libc::socklen_t; - Ok(()) - } - - pub fn get_control_mut(&mut self) -> Option<&mut [u8]> { - self.control.as_mut().map(|control| &mut control[..]) - } - - pub fn get_control_max_len(&self) -> usize { - self.control - .as_ref() - .map(|control| control.len()) - .unwrap_or(0) - } - - pub fn set_control_len(&mut self, new_control_len: usize) -> Result<()> { - if new_control_len > self.get_control_max_len() { - return_errno!(EINVAL, "new_control_len is too big"); - } - self.c_self.msg_controllen = new_control_len; - Ok(()) - } - - pub fn get_iovs_name_and_control_mut( - &mut self, - ) -> (&mut IovsMut<'a>, Option<&mut [u8]>, Option<&mut [u8]>) { - ( - &mut self.iovs, - self.name.as_mut().map(|name| &mut name[..]), - self.control.as_mut().map(|control| &mut control[..]), - ) - } - - pub fn set_flags(&mut self, flags: MsgHdrFlags) { - self.flags = flags; - self.c_self.msg_flags = flags.bits(); - } -} - -/// This struct is used to iterate through the control messages. -/// -/// `cmsghdr` is a C struct for ancillary data object information of a unix socket. -pub struct CMessages<'a> { - buffer: &'a [u8], - current: Option<&'a libc::cmsghdr>, -} - -impl<'a> Iterator for CMessages<'a> { - type Item = CmsgData<'a>; - - fn next(&mut self) -> Option { - let cmsg = unsafe { - let mut msg: libc::msghdr = core::mem::zeroed(); - msg.msg_control = self.buffer.as_ptr() as *mut _; - msg.msg_controllen = self.buffer.len() as _; - - let cmsg = if let Some(current) = self.current { - libc::CMSG_NXTHDR(&msg, current) - } else { - libc::CMSG_FIRSTHDR(&msg) - }; - cmsg.as_ref()? - }; - - self.current = Some(cmsg); - CmsgData::try_from_cmsghdr(cmsg) - } -} - -impl<'a> CMessages<'a> { - pub fn from_bytes(msg_control: &'a mut [u8]) -> Self { - Self { - buffer: msg_control, - current: None, - } - } -} - -/// Control message data of variable type. The data resides next to `cmsghdr`. -pub enum CmsgData<'a> { - ScmRights(ScmRights<'a>), - ScmCredentials, -} - -impl<'a> CmsgData<'a> { - /// Create an `CmsgData::ScmRights` variant. - /// - /// # Safety - /// - /// `data` must contain a valid control message and the control message must be type of - /// `SOL_SOCKET` and level of `SCM_RIGHTS`. - unsafe fn as_rights(data: &'a mut [u8]) -> Self { - let scm_rights = ScmRights { data }; - CmsgData::ScmRights(scm_rights) - } - - /// Create an `CmsgData::ScmCredentials` variant. - /// - /// # Safety - /// - /// `data` must contain a valid control message and the control message must be type of - /// `SOL_SOCKET` and level of `SCM_CREDENTIALS`. - unsafe fn as_credentials(_data: &'a [u8]) -> Self { - CmsgData::ScmCredentials - } - - fn try_from_cmsghdr(cmsg: &'a libc::cmsghdr) -> Option { - unsafe { - let cmsg_len_zero = libc::CMSG_LEN(0) as usize; - let data_len = (*cmsg).cmsg_len as usize - cmsg_len_zero; - let data = libc::CMSG_DATA(cmsg); - let data = core::slice::from_raw_parts_mut(data, data_len); - - match (*cmsg).cmsg_level { - libc::SOL_SOCKET => match (*cmsg).cmsg_type { - libc::SCM_RIGHTS => Some(CmsgData::as_rights(data)), - libc::SCM_CREDENTIALS => Some(CmsgData::as_credentials(data)), - _ => None, - }, - _ => None, - } - } - } -} - -/// The data unit of this control message is file descriptor(s). -/// -/// The level is equal to `SOL_SOCKET` and the type is equal to `SCM_RIGHTS`. -pub struct ScmRights<'a> { - data: &'a mut [u8], -} - -impl<'a> ScmRights<'a> { - /// Iterate and reassign each fd in data buffer, given a reassignment function. - pub fn iter_and_reassign_fds(&mut self, reassign_fd_fn: F) - where - F: Fn(FileDesc) -> FileDesc, - { - for fd_bytes in self.data.chunks_exact_mut(core::mem::size_of::()) { - let old_fd = FileDesc::from_ne_bytes(fd_bytes.try_into().unwrap()); - let reassigned_fd = reassign_fd_fn(old_fd); - fd_bytes.copy_from_slice(&reassigned_fd.to_ne_bytes()); - } - } - - pub fn iter_fds(&self) -> impl Iterator + '_ { - self.data - .chunks_exact(core::mem::size_of::()) - .map(|fd_bytes| FileDesc::from_ne_bytes(fd_bytes.try_into().unwrap())) - } -} - -unsafe fn new_optional_slice<'a, T>(slice_ptr: *const T, slice_size: usize) -> Option<&'a [T]> { - if !slice_ptr.is_null() { - let slice = core::slice::from_raw_parts::(slice_ptr, slice_size); - Some(slice) - } else { - None - } -} - -unsafe fn new_optional_slice_mut<'a, T>( - slice_ptr: *mut T, - slice_size: usize, -) -> Option<&'a mut [T]> { - if !slice_ptr.is_null() { - let slice = core::slice::from_raw_parts_mut::(slice_ptr, slice_size); - Some(slice) - } else { - None - } -} diff --git a/src/libos/src/net/socket/shutdown.rs b/src/libos/src/net/socket/shutdown.rs deleted file mode 100644 index 8b07abb8..00000000 --- a/src/libos/src/net/socket/shutdown.rs +++ /dev/null @@ -1,28 +0,0 @@ -use super::*; - -bitflags! { - pub struct HowToShut: c_int { - const READ = 0; - const WRITE = 1; - const BOTH = 2; - } -} - -impl HowToShut { - pub fn try_from_raw(how: c_int) -> Result { - match how { - 0 => Ok(Self::READ), - 1 => Ok(Self::WRITE), - 2 => Ok(Self::BOTH), - _ => return_errno!(EINVAL, "invalid how"), - } - } - - pub fn to_shut_read(&self) -> bool { - *self == Self::READ || *self == Self::BOTH - } - - pub fn to_shut_write(&self) -> bool { - *self == Self::WRITE || *self == Self::BOTH - } -} diff --git a/src/libos/src/net/socket/socket_type.rs b/src/libos/src/net/socket/socket_type.rs deleted file mode 100644 index af1744a0..00000000 --- a/src/libos/src/net/socket/socket_type.rs +++ /dev/null @@ -1,29 +0,0 @@ -use super::*; - -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -#[repr(i32)] -#[allow(non_camel_case_types)] -pub enum SocketType { - STREAM = 1, - DGRAM = 2, - RAW = 3, - RDM = 4, - SEQPACKET = 5, - DCCP = 6, - PACKET = 10, -} - -impl SocketType { - pub fn try_from(sock_type: i32) -> Result { - match sock_type { - 1 => Ok(SocketType::STREAM), - 2 => Ok(SocketType::DGRAM), - 3 => Ok(SocketType::RAW), - 4 => Ok(SocketType::RDM), - 5 => Ok(SocketType::SEQPACKET), - 6 => Ok(SocketType::DCCP), - 10 => Ok(SocketType::PACKET), - _ => return_errno!(EINVAL, "invalid socket type"), - } - } -} diff --git a/src/libos/src/net/socket/unix/addr.rs b/src/libos/src/net/socket/unix/addr.rs deleted file mode 100644 index 08925504..00000000 --- a/src/libos/src/net/socket/unix/addr.rs +++ /dev/null @@ -1,156 +0,0 @@ -use super::*; -use std::path::{Path, PathBuf}; -use std::{cmp, mem, slice, str}; - -const MAX_PATH_LEN: usize = 108; -const SUN_FAMILY_LEN: usize = mem::size_of::(); -lazy_static! { - static ref SUN_PATH_OFFSET: usize = memoffset::offset_of!(libc::sockaddr_un, sun_path); -} - -#[derive(Clone, Debug, Eq, PartialEq)] -pub enum Addr { - File(Option, UnixPath), // An optional inode number and path. Use inode if there is one. - Abstract(String), -} - -impl Addr { - /// Caller should guarentee the sockaddr and addr_len are valid. - /// The pathname should end with a '\0' within the passed length. - /// The abstract name should both start and end with a '\0' within the passed length. - pub unsafe fn try_from_raw( - sockaddr: *const libc::sockaddr, - addr_len: libc::socklen_t, - ) -> Result { - let addr_len = addr_len as usize; - - // TODO: support autobind to validate when addr_len == SUN_FAMILY_LEN - if addr_len <= SUN_FAMILY_LEN { - return_errno!(EINVAL, "the address is too short."); - } - - if addr_len > MAX_PATH_LEN + *SUN_PATH_OFFSET { - return_errno!(EINVAL, "the address is too long."); - } - - if AddressFamily::try_from((*sockaddr).sa_family)? != AddressFamily::LOCAL { - return_errno!(EINVAL, "not a valid address for unix socket"); - } - - let sockaddr = sockaddr as *const libc::sockaddr_un; - let sun_path = (*sockaddr).sun_path; - - if sun_path[0] == 0 { - let path_ptr = sun_path[1..(addr_len - *SUN_PATH_OFFSET)].as_ptr(); - let path_slice = - slice::from_raw_parts(path_ptr as *const u8, addr_len - *SUN_PATH_OFFSET - 1); - - Ok(Self::Abstract( - str::from_utf8(&path_slice).unwrap().to_string(), - )) - } else { - let path_cstr = CStr::from_ptr(sun_path.as_ptr()); - if path_cstr.to_bytes_with_nul().len() > MAX_PATH_LEN { - return_errno!(EINVAL, "no null in the address"); - } - - Ok(Self::File(None, UnixPath::new(path_cstr.to_str().unwrap()))) - } - } - - pub fn copy_to_slice(&self, dst: &mut [u8]) -> usize { - let (raw_addr, addr_len) = self.to_raw(); - let src = - unsafe { std::slice::from_raw_parts(&raw_addr as *const _ as *const u8, addr_len) }; - let copied = std::cmp::min(dst.len(), addr_len); - dst[..copied].copy_from_slice(&src[..copied]); - copied - } - - pub fn raw_len(&self) -> usize { - /// The '/0' at the end of Self::File counts - self.path_str().len() - + 1 - + *SUN_PATH_OFFSET - } - - pub fn path_str(&self) -> &str { - match self { - Self::File(_, unix_path) => &unix_path.path_str(), - Self::Abstract(path) => &path, - } - } - - fn to_raw(&self) -> (libc::sockaddr_un, usize) { - let mut addr: libc::sockaddr_un = unsafe { mem::zeroed() }; - addr.sun_family = AddressFamily::LOCAL as libc::sa_family_t; - - let addr_len = match self { - Self::File(_, unix_path) => { - let path_str = unix_path.path_str(); - let buf_len = path_str.len(); - /// addr is initialized to all zeros and try_from_raw guarentees - /// unix_path length is shorter than sun_path, so sun_path here - /// will always have a null terminator - addr.sun_path[..buf_len] - .copy_from_slice(unsafe { &*(path_str.as_bytes() as *const _ as *const [i8]) }); - buf_len + *SUN_PATH_OFFSET + 1 - } - Self::Abstract(path_str) => { - addr.sun_path[0] = 0; - let buf_len = path_str.len() + 1; - addr.sun_path[1..buf_len] - .copy_from_slice(unsafe { &*(path_str.as_bytes() as *const _ as *const [i8]) }); - buf_len + *SUN_PATH_OFFSET - } - }; - - (addr, addr_len) - } -} - -#[derive(Clone, Debug, Eq, PartialEq)] -pub struct UnixPath { - inner: PathBuf, - /// Holds the cwd when a relative path is created - cwd: Option, -} - -impl UnixPath { - pub fn new(path: &str) -> Self { - let inner = PathBuf::from(path); - let is_absolute = inner.is_absolute(); - Self { - inner: inner, - cwd: if is_absolute { - None - } else { - let thread = current!(); - let fs = thread.fs().read().unwrap(); - let cwd = fs.cwd().to_owned(); - - Some(cwd) - }, - } - } - - pub fn absolute(&self) -> String { - let path_str = self.path_str(); - if self.inner.is_absolute() { - path_str.to_string() - } else { - let mut prefix = self.cwd.as_ref().unwrap().clone(); - if prefix.ends_with("/") { - prefix.push_str(path_str); - } else { - prefix.push_str("/"); - prefix.push_str(path_str); - } - prefix - } - } - - pub fn path_str(&self) -> &str { - self.inner.to_str().unwrap() - } -} diff --git a/src/libos/src/net/socket/unix/mod.rs b/src/libos/src/net/socket/unix/mod.rs index 525d9f39..f8260232 100644 --- a/src/libos/src/net/socket/unix/mod.rs +++ b/src/libos/src/net/socket/unix/mod.rs @@ -1,19 +1,16 @@ -use self::addr::Addr; use super::*; -mod addr; mod stream; -pub use self::addr::Addr as UnixAddr; pub use self::stream::Stream; //TODO: rewrite this file when a new kind of uds is added -pub fn unix_socket(socket_type: SocketType, flags: FileFlags, protocol: i32) -> Result { - if protocol != 0 && protocol != AddressFamily::LOCAL as i32 { +pub fn unix_socket(socket_type: Type, flags: SocketFlags, protocol: i32) -> Result { + if protocol != 0 && protocol != Domain::LOCAL as i32 { return_errno!(EPROTONOSUPPORT, "protocol is not supported"); } - if socket_type == SocketType::STREAM { + if socket_type == Type::STREAM { Ok(Stream::new(flags)) } else { return_errno!(ESOCKTNOSUPPORT, "only stream type is supported"); @@ -21,15 +18,15 @@ pub fn unix_socket(socket_type: SocketType, flags: FileFlags, protocol: i32) -> } pub fn socketpair( - socket_type: SocketType, - flags: FileFlags, + socket_type: Type, + flags: SocketFlags, protocol: i32, ) -> Result<(Stream, Stream)> { - if protocol != 0 && protocol != AddressFamily::LOCAL as i32 { + if protocol != 0 && protocol != Domain::LOCAL as i32 { return_errno!(EPROTONOSUPPORT, "protocol is not supported"); } - if socket_type == SocketType::STREAM { + if socket_type == Type::STREAM { Stream::socketpair(flags) } else { return_errno!(ESOCKTNOSUPPORT, "only stream type is supported"); diff --git a/src/libos/src/net/socket/unix/stream/address_space.rs b/src/libos/src/net/socket/unix/stream/address_space.rs index 551a8ca6..24485f98 100644 --- a/src/libos/src/net/socket/unix/stream/address_space.rs +++ b/src/libos/src/net/socket/unix/stream/address_space.rs @@ -39,9 +39,9 @@ impl AddressSpace { } } - pub fn add_binder(&self, addr: &Addr) -> Result<()> { + pub fn add_binder(&self, addr: &UnixAddr) -> Result<()> { let key = Self::get_key(addr).ok_or_else(|| errno!(EINVAL, "can't find socket file"))?; - let mut space = self.get_space(addr); + let mut space = self.get_space(addr)?; if space.contains_key(&key) { return_errno!(EADDRINUSE, "the addr is already bound"); } else { @@ -52,13 +52,13 @@ impl AddressSpace { pub(super) fn add_listener( &self, - addr: &Addr, + addr: &UnixAddr, capacity: usize, nonblocking: bool, notifier: Arc, ) -> Result<()> { let key = Self::get_key(addr).ok_or_else(|| errno!(EINVAL, "the socket is not bound"))?; - let mut space = self.get_space(addr); + let mut space = self.get_space(addr)?; if let Some(option) = space.get(&key) { if option.is_none() { @@ -75,9 +75,9 @@ impl AddressSpace { } } - pub fn resize_listener(&self, addr: &Addr, capacity: usize) -> Result<()> { + pub fn resize_listener(&self, addr: &UnixAddr, capacity: usize) -> Result<()> { let key = Self::get_key(addr).ok_or_else(|| errno!(EINVAL, "the socket is not bound"))?; - let mut space = self.get_space(addr); + let mut space = self.get_space(addr)?; if let Some(option) = space.get(&key) { if let Some(listener) = option { @@ -91,33 +91,33 @@ impl AddressSpace { } } - pub fn push_incoming(&self, addr: &Addr, sock: Endpoint) -> Result<()> { + pub fn push_incoming(&self, addr: &UnixAddr, sock: Endpoint) -> Result<()> { self.get_listener_ref(addr) .ok_or_else(|| errno!(ECONNREFUSED, "no one's listening on the remote address"))? .push_incoming(sock) } - pub fn pop_incoming(&self, addr: &Addr) -> Result { + pub fn pop_incoming(&self, addr: &UnixAddr) -> Result { self.get_listener_ref(addr) .ok_or_else(|| errno!(EINVAL, "the socket is not listening"))? .pop_incoming() .ok_or_else(|| errno!(EAGAIN, "No connection is incoming")) } - pub fn get_listener_ref(&self, addr: &Addr) -> Option> { + pub fn get_listener_ref(&self, addr: &UnixAddr) -> Option> { let key = Self::get_key(addr); if let Some(key) = key { - let space = self.get_space(addr); + let space = self.get_space(addr).unwrap(); space.get(&key).map(|x| x.clone()).flatten() } else { None } } - pub fn remove_addr(&self, addr: &Addr) { + pub fn remove_addr(&self, addr: &UnixAddr) { let key = Self::get_key(addr); if let Some(key) = key { - let mut space = self.get_space(addr); + let mut space = self.get_space(addr).unwrap(); space.remove(&key); } else { warn!("address space key not exit: {:?}", addr); @@ -126,21 +126,22 @@ impl AddressSpace { fn get_space( &self, - addr: &Addr, - ) -> SgxMutexGuard<'_, BTreeMap>>> { + addr: &UnixAddr, + ) -> Result>>>> { match addr { - Addr::File(_, _) => self.file.lock().unwrap(), - Addr::Abstract(_) => self.abstr.lock().unwrap(), + UnixAddr::File(_, _) => Ok(self.file.lock().unwrap()), + UnixAddr::Abstract(_) => Ok(self.abstr.lock().unwrap()), + UnixAddr::Unnamed => return_errno!(EINVAL, "can't get path name for unnamed socket"), } } - fn get_key(addr: &Addr) -> Option { + fn get_key(addr: &UnixAddr) -> Option { trace!("addr = {:?}", addr); match addr { - Addr::File(inode_num, unix_path) if inode_num.is_some() => { + UnixAddr::File(inode_num, unix_path) if inode_num.is_some() => { Some(AddressSpaceKey::from_inode(inode_num.unwrap())) } - Addr::File(_, unix_path) => { + UnixAddr::File(_, unix_path) => { let inode = { let file_path = unix_path.absolute(); let current = current!(); @@ -153,7 +154,10 @@ impl AddressSpace { None } } - Addr::Abstract(path) => Some(AddressSpaceKey::from_path(addr.path_str().to_string())), + UnixAddr::Abstract(path) => Some(AddressSpaceKey::from_path( + addr.path_str().unwrap().to_string(), + )), + UnixAddr::Unnamed => None, } } } diff --git a/src/libos/src/net/socket/unix/stream/endpoint.rs b/src/libos/src/net/socket/unix/stream/endpoint.rs index 9ddbb7f0..61f36d71 100644 --- a/src/libos/src/net/socket/unix/stream/endpoint.rs +++ b/src/libos/src/net/socket/unix/stream/endpoint.rs @@ -39,7 +39,7 @@ pub fn end_pair(nonblocking: bool) -> Result<(Endpoint, Endpoint)> { /// One end of the connected unix socket pub struct Inner { - addr: RwLock>, + addr: RwLock>, reader: Consumer, writer: Producer, peer: Weak, @@ -47,15 +47,15 @@ pub struct Inner { } impl Inner { - pub fn addr(&self) -> Option { + pub fn addr(&self) -> Option { self.addr.read().unwrap().clone() } - pub fn set_addr(&self, addr: &Addr) { + pub fn set_addr(&self, addr: &UnixAddr) { *self.addr.write().unwrap() = Some(addr.clone()); } - pub fn peer_addr(&self) -> Option { + pub fn peer_addr(&self) -> Option { self.peer.upgrade().map(|end| end.addr().clone()).flatten() } @@ -90,16 +90,16 @@ impl Inner { self.reader.items_to_consume() } - pub fn shutdown(&self, how: HowToShut) -> Result<()> { + pub fn shutdown(&self, how: Shutdown) -> Result<()> { if !self.is_connected() { return_errno!(ENOTCONN, "The socket is not connected."); } - if how.to_shut_read() { + if how.should_shut_read() { self.reader.shutdown() } - if how.to_shut_write() { + if how.should_shut_write() { self.writer.shutdown() } diff --git a/src/libos/src/net/socket/unix/stream/file.rs b/src/libos/src/net/socket/unix/stream/file.rs index 9fd877bc..7bfee49a 100644 --- a/src/libos/src/net/socket/unix/stream/file.rs +++ b/src/libos/src/net/socket/unix/stream/file.rs @@ -1,10 +1,12 @@ use super::address_space::ADDRESS_SPACE; use super::stream::Status; use super::*; -use fs::{AccessMode, File, FileRef, IoEvents, IoNotifier, IoctlCmd, StatusFlags}; +use fs::{AccessMode, File, IoEvents, IoNotifier, StatusFlags}; use rcore_fs::vfs::{FileType, Metadata, Timespec}; use std::any::Any; +use crate::fs::{GetReadBufLen, IoctlCmd, SetNonBlocking}; + impl File for Stream { fn read(&self, buf: &mut [u8]) -> Result { // The connected status will not be changed any more @@ -55,23 +57,23 @@ impl File for Stream { } } - fn ioctl(&self, cmd: &mut IoctlCmd) -> Result { - match cmd { - IoctlCmd::TCGETS(_) => return_errno!(ENOTTY, "not tty device"), - IoctlCmd::TCSETS(_) => return_errno!(ENOTTY, "not tty device"), - IoctlCmd::FIONBIO(nonblocking) => { - self.set_nonblocking(**nonblocking != 0); - } - IoctlCmd::FIONREAD(arg) => match &*self.inner() { - Status::Connected(endpoint) => { - let bytes_to_read = endpoint.bytes_to_read().min(std::i32::MAX as usize) as i32; - **arg = bytes_to_read; - } - _ => return_errno!(ENOTCONN, "unconnected socket"), + fn ioctl(&self, cmd: &mut dyn IoctlCmd) -> Result<()> { + match_ioctl_cmd_auto_error!(cmd, { + cmd : GetReadBufLen => { + match &*self.inner() { + Status::Connected(endpoint) => { + let bytes_to_read = endpoint.bytes_to_read().min(std::i32::MAX as usize) as i32; + cmd.set_output(bytes_to_read as _); + } + _ => return_errno!(ENOTCONN, "unconnected socket"), + }; }, - _ => return_errno!(EINVAL, "unknown ioctl cmd for unix socket"), - } - Ok(0) + cmd : SetNonBlocking => { + let nonblocking = cmd.input(); + self.set_nonblocking(*nonblocking != 0); + } + }); + Ok(()) } fn access_mode(&self) -> Result { diff --git a/src/libos/src/net/socket/unix/stream/stream.rs b/src/libos/src/net/socket/unix/stream/stream.rs index 4bea1a07..fc1711e2 100644 --- a/src/libos/src/net/socket/unix/stream/stream.rs +++ b/src/libos/src/net/socket/unix/stream/stream.rs @@ -5,7 +5,7 @@ use events::{Event, EventFilter, Notifier, Observer}; use fs::channel::Channel; use fs::IoEvents; use fs::{CreationFlags, FileMode}; -use net::socket::{CMessages, CmsgData, Iovs, MsgHdr, MsgHdrMut}; +use net::socket::{CMessages, CmsgData}; use std::fmt; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; @@ -23,17 +23,17 @@ pub struct Stream { } impl Stream { - pub fn new(flags: FileFlags) -> Self { + pub fn new(flags: SocketFlags) -> Self { Self { inner: SgxMutex::new(Status::Idle(Info::new( - flags.contains(FileFlags::SOCK_NONBLOCK), + flags.contains(SocketFlags::SOCK_NONBLOCK), ))), notifier: Arc::new(RelayNotifier::new()), } } - pub fn socketpair(flags: FileFlags) -> Result<(Self, Self)> { - let nonblocking = flags.contains(FileFlags::SOCK_NONBLOCK); + pub fn socketpair(flags: SocketFlags) -> Result<(Self, Self)> { + let nonblocking = flags.contains(SocketFlags::SOCK_NONBLOCK); let (end_a, end_b) = end_pair(nonblocking)?; let notifier_a = Arc::new(RelayNotifier::new()); let notifier_b = Arc::new(RelayNotifier::new()); @@ -53,15 +53,17 @@ impl Stream { Ok((socket_a, socket_b)) } - pub fn addr(&self) -> Option { - match &*self.inner() { + pub fn addr(&self) -> UnixAddr { + let addr_opt = match &*self.inner() { Status::Idle(info) => info.addr().clone(), Status::Connected(endpoint) => endpoint.addr(), Status::Listening(addr) => Some(addr).cloned(), - } + }; + + addr_opt.unwrap_or(UnixAddr::Unnamed) } - pub fn peer_addr(&self) -> Result { + pub fn peer_addr(&self) -> Result { if let Status::Connected(endpoint) = &*self.inner() { if let Some(addr) = endpoint.peer_addr() { return Ok(addr); @@ -70,8 +72,8 @@ impl Stream { return_errno!(ENOTCONN, "the socket is not connected"); } - pub fn bind(&self, addr: &mut Addr) -> Result<()> { - if let Addr::File(inode_num, path) = addr { + pub fn bind(&self, addr: &mut UnixAddr) -> Result<()> { + if let UnixAddr::File(inode_num, path) = addr { // create the corresponding file in the fs and fill Addr with its inode let corresponding_inode_num = { let current = current!(); @@ -143,7 +145,7 @@ impl Stream { /// The establishment of the connection is very fast and can be done immediately. /// Therefore, the connect function in our implementation will never block. - pub fn connect(&self, addr: &Addr) -> Result<()> { + pub fn connect(&self, addr: &UnixAddr) -> Result<()> { debug!("connect to {:?}", addr); let mut inner = self.inner(); @@ -168,7 +170,7 @@ impl Stream { ADDRESS_SPACE .push_incoming(addr, end_incoming) .map_err(|e| match e.errno() { - Errno::EAGAIN => errno!(ECONNREFUSED, "the backlog is full"), + EAGAIN => errno!(ECONNREFUSED, "the backlog is full"), _ => e, })?; @@ -187,12 +189,12 @@ impl Stream { } } - pub fn accept(&self, flags: FileFlags) -> Result<(Self, Option)> { + pub fn accept(&self, flags: SocketFlags) -> Result<(Self, Option)> { let status = (*self.inner()).clone(); match status { Status::Listening(addr) => { let endpoint = ADDRESS_SPACE.pop_incoming(&addr)?; - endpoint.set_nonblocking(flags.contains(FileFlags::SOCK_NONBLOCK)); + endpoint.set_nonblocking(flags.contains(SocketFlags::SOCK_NONBLOCK)); endpoint.set_ancillary(Ancillary { tid: current!().tid(), }); @@ -216,12 +218,12 @@ impl Stream { } // TODO: handle flags - pub fn sendto(&self, buf: &[u8], flags: SendFlags, addr: &Option) -> Result { + pub fn sendto(&self, buf: &[u8], flags: SendFlags, addr: &Option) -> Result { self.write(buf) } // TODO: handle flags - pub fn recvfrom(&self, buf: &mut [u8], flags: RecvFlags) -> Result<(usize, Option)> { + pub fn recvfrom(&self, buf: &mut [u8], flags: RecvFlags) -> Result<(usize, Option)> { let data_len = self.read(buf)?; let addr = self.peer_addr().ok(); @@ -230,33 +232,39 @@ impl Stream { Ok((data_len, addr)) } - pub fn sendmsg(&self, msg_hdr: &MsgHdr, flags: SendFlags) -> Result { + pub fn sendmsg( + &self, + bufs: &[&[u8]], + flags: SendFlags, + control: Option<&[u8]>, + ) -> Result { if !flags.is_empty() { warn!("unsupported flags: {:?}", flags); } - let bufs = msg_hdr.get_iovs().as_slices(); - let mut data_len = self.writev(bufs)?; - - if let Some(msg_control) = msg_hdr.get_control() { - data_len += self.write(msg_control)?; + let data_len = self.writev(bufs)?; + if let Some(msg_control) = control { + self.write(msg_control)?; } + Ok(data_len) } - pub fn recvmsg(&self, msg_hdr: &mut MsgHdrMut, flags: RecvFlags) -> Result { + pub fn recvmsg( + &self, + bufs: &mut [&mut [u8]], + flags: RecvFlags, + control: Option<&mut [u8]>, + ) -> Result<(usize, usize)> { if !flags.is_empty() { warn!("unsupported flags: {:?}", flags); } - let bufs = msg_hdr.get_iovs_mut().as_slices_mut(); - let mut data_len = self.readv(bufs)?; + let data_len = self.readv(bufs)?; // For stream socket, the msg_name is ignored. And other fields are not supported. - msg_hdr.set_name_len(0); - - if let Some(msg_control) = msg_hdr.get_control_mut() { - data_len += self.read(msg_control)?; + let control_len = if let Some(msg_control) = control { + let control_len = self.read(msg_control)?; // For each control message that contains file descriptors (SOL_SOCKET and SCM_RIGHTS), // reassign each fd in the message in receive end. @@ -268,7 +276,6 @@ impl Stream { .unwrap() .files() .lock() - .unwrap() .get(send_fd) .unwrap(); current!().add_file(ipc_file.clone(), false) @@ -276,12 +283,16 @@ impl Stream { } // Unix credentials need not to be handled here } - } - Ok(data_len) + control_len + } else { + 0 + }; + + Ok((data_len, control_len)) } /// perform shutdown on the socket. - pub fn shutdown(&self, how: HowToShut) -> Result<()> { + pub fn shutdown(&self, how: Shutdown) -> Result<()> { if let Status::Connected(ref end) = &*self.inner() { end.shutdown(how) } else { @@ -369,13 +380,13 @@ pub enum Status { Idle(Info), // The listeners are stored in a global data structure indexed by the address. // The consitency of Status with that data structure should be carefully maintained. - Listening(Addr), + Listening(UnixAddr), Connected(Endpoint), } #[derive(Debug, Clone)] pub struct Info { - addr: Option, + addr: Option, nonblocking: bool, } @@ -387,11 +398,11 @@ impl Info { } } - pub fn addr(&self) -> &Option { + pub fn addr(&self) -> &Option { &self.addr } - pub fn set_addr(&mut self, addr: &Addr) { + pub fn set_addr(&mut self, addr: &UnixAddr) { self.addr = Some(addr.clone()); } diff --git a/src/libos/src/net/socket/util/addr/c_sock_addr.rs b/src/libos/src/net/socket/util/addr/c_sock_addr.rs new file mode 100644 index 00000000..04138c97 --- /dev/null +++ b/src/libos/src/net/socket/util/addr/c_sock_addr.rs @@ -0,0 +1,134 @@ +use crate::prelude::*; +use std::mem::{size_of, size_of_val, MaybeUninit}; +/// A trait for all C version of C socket addresses. +/// +/// There are four types that implement this trait: +/// * `libc::sockaddr_in` +/// * `(libc::sockaddr_in, usize)` +/// * `(libc::sockaddr_un, usize)` +/// * `(libc::sockaddr_storage, usize)`. +pub trait CSockAddr { + /// The network family of the address. + fn c_family(&self) -> libc::sa_family_t; + + /// The address in bytes (excluding the family part). + fn c_addr(&self) -> &[u8]; + + /// Returns the address in `libc::sockaddr_storage` along with its length. + fn to_c_storage(&self) -> (libc::sockaddr_storage, usize) { + let mut c_storage = + unsafe { MaybeUninit::::uninit().assume_init() }; + + c_storage.ss_family = self.c_family(); + let offset = size_of_val(&c_storage.ss_family); + + let c_storage_len = offset + self.c_addr().len(); + assert!(c_storage_len <= size_of::()); + + let c_storage_remain = unsafe { + let ptr = (&mut c_storage as *mut _ as *mut u8).add(offset); + let len = self.c_addr().len(); + std::slice::from_raw_parts_mut(ptr, len) + }; + c_storage_remain.copy_from_slice(self.c_addr()); + (c_storage, c_storage_len) + } +} + +impl CSockAddr for libc::sockaddr_in { + fn c_family(&self) -> libc::sa_family_t { + libc::AF_INET as _ + } + + fn c_addr(&self) -> &[u8] { + // Safety. The slice is part of self. + unsafe { + let addr_ptr = (self as *const _ as *const u8).add(size_of_val(&self.sin_family)); + std::slice::from_raw_parts( + addr_ptr, + size_of::() - size_of_val(&self.sin_family), + ) + } + } +} + +impl CSockAddr for (libc::sockaddr_in, usize) { + fn c_family(&self) -> libc::sa_family_t { + self.0.c_family() + } + + fn c_addr(&self) -> &[u8] { + assert!(self.1 == size_of::()); + self.0.c_addr() + } +} + +impl CSockAddr for (libc::sockaddr_in6, usize) { + fn c_family(&self) -> libc::sa_family_t { + self.0.sin6_family + } + + fn c_addr(&self) -> &[u8] { + assert!(self.1 == size_of::()); + unsafe { + let addr_ptr = (&self.0 as *const _ as *const u8).add(size_of_val(&self.c_family())); + std::slice::from_raw_parts( + addr_ptr, + size_of::() - size_of_val(&self.c_family()), + ) + } + } +} + +impl CSockAddr for (libc::sockaddr_un, usize) { + fn c_family(&self) -> libc::sa_family_t { + libc::AF_UNIX as _ + } + + fn c_addr(&self) -> &[u8] { + assert!( + size_of::() <= self.1 && self.1 <= size_of::() + ); + // Safety. The slice is part of self. + unsafe { + let addr_ptr = (&self.0 as *const _ as *const u8).add(size_of_val(&self.0.sun_family)); + std::slice::from_raw_parts(addr_ptr, self.1 - size_of_val(&self.0.sun_family)) + } + } +} + +impl CSockAddr for (libc::sockaddr_nl, usize) { + fn c_family(&self) -> libc::sa_family_t { + libc::AF_NETLINK as _ + } + + fn c_addr(&self) -> &[u8] { + assert!(self.1 == size_of::()); + + unsafe { + let addr_ptr = (&self.0 as *const _ as *const u8).add(size_of_val(&self.c_family())); + std::slice::from_raw_parts( + addr_ptr, + size_of::() - size_of_val(&self.c_family()), + ) + } + } +} + +impl CSockAddr for (libc::sockaddr_storage, usize) { + fn c_family(&self) -> libc::sa_family_t { + self.0.ss_family + } + + fn c_addr(&self) -> &[u8] { + assert!( + size_of::() <= self.1 + && self.1 <= size_of::() + ); + // Safety. The slice is part of self. + unsafe { + let addr_ptr = (&self.0 as *const _ as *const u8).add(size_of_val(&self.0.ss_family)); + std::slice::from_raw_parts(addr_ptr, self.1 - size_of_val(&self.0.ss_family)) + } + } +} diff --git a/src/libos/src/net/socket/util/addr/ipv4.rs b/src/libos/src/net/socket/util/addr/ipv4.rs new file mode 100644 index 00000000..29b818ce --- /dev/null +++ b/src/libos/src/net/socket/util/addr/ipv4.rs @@ -0,0 +1,137 @@ +use std::any::Any; +use std::fmt::{self, Debug}; + +use super::{Addr, CSockAddr, Domain, RawAddr}; +use crate::prelude::*; + +/// An IPv4 socket address, consisting of an IPv4 address and a port. +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub struct Ipv4SocketAddr { + ip: Ipv4Addr, + port: u16, +} + +impl Addr for Ipv4SocketAddr { + fn domain() -> Domain { + Domain::INET + } + + fn from_c_storage(c_addr: &libc::sockaddr_storage, c_addr_len: usize) -> Result { + if c_addr_len > std::mem::size_of::() { + return_errno!(EINVAL, "address length is too large"); + } + + // The c_addr_len is certainly not smaller than the length of IN_ADDR_ANY. + // https://en.wikipedia.org/wiki/IPv4 + if c_addr_len < std::mem::size_of::() { + return_errno!(EINVAL, "address length is too small"); + } + // Safe to convert from sockaddr_storage to sockaddr_in + let c_addr = unsafe { std::mem::transmute(c_addr) }; + Self::from_c(c_addr) + } + + fn to_c_storage(&self) -> (libc::sockaddr_storage, usize) { + let c_in_addr = self.to_c(); + c_in_addr.to_c_storage() + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn is_default(&self) -> bool { + let inaddr_any = Self::default(); + *self == inaddr_any + } +} + +impl Ipv4SocketAddr { + pub fn new(ip: Ipv4Addr, port: u16) -> Self { + Self { ip, port } + } + + pub fn from_c(c_addr: &libc::sockaddr_in) -> Result { + if c_addr.sin_family != libc::AF_INET as libc::sa_family_t { + return_errno!(EINVAL, "an ipv4 address is expected"); + } + Ok(Self { + port: u16::from_be(c_addr.sin_port), + ip: Ipv4Addr::from_c(&c_addr.sin_addr), + }) + } + + pub fn to_c(&self) -> libc::sockaddr_in { + libc::sockaddr_in { + sin_family: libc::AF_INET as _, + sin_port: self.port.to_be(), + sin_addr: self.ip.to_c(), + sin_zero: [0; 8], + } + } + + pub fn to_raw(&self) -> RawAddr { + let (storage, len) = self.to_c_storage(); + RawAddr::from_c_storage(&storage, len) + } + + pub fn ip(&self) -> &Ipv4Addr { + &self.ip + } + + pub fn port(&self) -> u16 { + self.port + } + + pub fn set_ip(&mut self, new_ip: Ipv4Addr) { + self.ip = new_ip; + } + + pub fn set_port(&mut self, new_port: u16) { + self.port = new_port; + } +} + +impl Default for Ipv4SocketAddr { + fn default() -> Self { + let addr = Ipv4Addr::new(0, 0, 0, 0); + Self::new(addr, 0) + } +} + +/// An Ipv4 address. +#[derive(Copy, Clone, PartialEq, Eq)] +pub struct Ipv4Addr([u8; 4] /* big endian */); + +impl Ipv4Addr { + /// Creates a new IPv4 address of `a.b.c.d`. + pub fn new(a: u8, b: u8, c: u8, d: u8) -> Self { + Self([a, b, c, d]) + } + + /// Creates a new IPv4 address from its C counterpart. + pub fn from_c(c_addr: &libc::in_addr) -> Self { + Self(c_addr.s_addr.to_ne_bytes()) + } + + /// Return the C counterpart. + pub fn to_c(&self) -> libc::in_addr { + libc::in_addr { + s_addr: u32::from_ne_bytes(self.0), + } + } + + /// Return the four digits that make up the address. + /// + /// Assuming the address is `a.b.c.d`, the returned value would be `[a, b, c, d]`. + pub fn octets(&self) -> &[u8; 4] { + &self.0 + } +} + +impl fmt::Debug for Ipv4Addr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let [a, b, c, d] = *self.octets(); + write!(f, "Ipv4Addr ({}.{}.{}.{})", &a, &b, &c, &d) + } +} diff --git a/src/libos/src/net/socket/util/addr/ipv6.rs b/src/libos/src/net/socket/util/addr/ipv6.rs new file mode 100644 index 00000000..764c0a54 --- /dev/null +++ b/src/libos/src/net/socket/util/addr/ipv6.rs @@ -0,0 +1,115 @@ +use std::any::Any; +use std::fmt::Debug; + +use super::RawAddr; +use super::{Addr, CSockAddr, Domain}; +use crate::prelude::*; +use libc::in6_addr; +use libc::sockaddr_in6; + +pub use std::net::Ipv6Addr; + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub struct Ipv6SocketAddr { + ip: Ipv6Addr, + port: u16, + flowinfo: u32, + scope_id: u32, +} + +impl Addr for Ipv6SocketAddr { + fn domain() -> Domain { + Domain::INET6 + } + + fn from_c_storage(c_addr: &libc::sockaddr_storage, c_addr_len: usize) -> Result { + if c_addr_len > std::mem::size_of::() { + return_errno!(EINVAL, "address length is too large"); + } + + if c_addr_len < std::mem::size_of::() { + return_errno!(EINVAL, "address length is too small"); + } + // Safe to convert from sockaddr_storage to sockaddr_in + let c_addr: &sockaddr_in6 = unsafe { std::mem::transmute(c_addr) }; + Self::from_c(c_addr) + } + + fn to_c_storage(&self) -> (libc::sockaddr_storage, usize) { + let c_addr = self.to_c(); + (c_addr, std::mem::size_of::()).to_c_storage() + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn is_default(&self) -> bool { + let in6addr_any_init = Self::default(); + *self == in6addr_any_init + } +} + +impl Ipv6SocketAddr { + pub fn new(ip: Ipv6Addr, port: u16, flowinfo: u32, scope_id: u32) -> Self { + Self { + ip, + port, + flowinfo, + scope_id, + } + } + + pub fn from_c(c_addr: &libc::sockaddr_in6) -> Result { + if c_addr.sin6_family != libc::AF_INET6 as libc::sa_family_t { + return_errno!(EINVAL, "an ipv6 address is expected"); + } + Ok(Self { + port: u16::from_be(c_addr.sin6_port), + ip: Ipv6Addr::from(c_addr.sin6_addr.s6_addr), + flowinfo: u32::from_be(c_addr.sin6_flowinfo), + scope_id: u32::from_be(c_addr.sin6_scope_id), + }) + } + + pub fn to_c(&self) -> libc::sockaddr_in6 { + let in6_addr = in6_addr { + s6_addr: self.ip.octets(), + }; + libc::sockaddr_in6 { + sin6_family: libc::AF_INET6 as _, + sin6_port: self.port.to_be(), + sin6_addr: in6_addr, + sin6_flowinfo: self.flowinfo.to_be(), + sin6_scope_id: self.flowinfo.to_be(), + } + } + + pub fn to_raw(&self) -> RawAddr { + let (storage, len) = self.to_c_storage(); + RawAddr::from_c_storage(&storage, len) + } + + pub fn ip(&self) -> &Ipv6Addr { + &self.ip + } + + pub fn port(&self) -> u16 { + self.port + } + + pub fn set_ip(&mut self, new_ip: Ipv6Addr) { + self.ip = new_ip; + } + + pub fn set_port(&mut self, new_port: u16) { + self.port = new_port; + } +} + +impl Default for Ipv6SocketAddr { + fn default() -> Self { + let addr = Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0); + Self::new(addr, 0, 0, 0) + } +} diff --git a/src/libos/src/net/socket/util/addr/mod.rs b/src/libos/src/net/socket/util/addr/mod.rs new file mode 100644 index 00000000..66e08b2d --- /dev/null +++ b/src/libos/src/net/socket/util/addr/mod.rs @@ -0,0 +1,87 @@ +use std::any::Any; +use std::fmt::Debug; + +use crate::net::Domain; +use crate::prelude::*; + +mod c_sock_addr; +mod ipv4; +mod ipv6; +mod raw_addr; +mod unix_addr; + +/// A trait for network addresses. +pub trait Addr: Clone + Debug + Default + PartialEq + Send + Sync { + /// Return the domain that the address belongs to. + fn domain() -> Domain + where + Self: Sized; + + /// Construct a new address from C's sockaddr_storage. + /// + /// The length argument specify how much bytes in the given sockaddr_storage are to be + /// interpreted as parts of the address. + fn from_c_storage(c_addr: &libc::sockaddr_storage, c_addr_len: usize) -> Result + where + Self: Sized; + + /// Converts the address to C's sockaddr_storage. + /// + /// The actual length used in sockaddr_storage is also returned. + fn to_c_storage(&self) -> (libc::sockaddr_storage, usize); + + fn as_any(&self) -> &dyn Any; + + fn is_default(&self) -> bool; +} + +pub use self::c_sock_addr::CSockAddr; +pub use self::ipv4::{Ipv4Addr, Ipv4SocketAddr}; +pub use self::ipv6::{Ipv6Addr, Ipv6SocketAddr}; +pub use self::raw_addr::RawAddr; +pub use self::unix_addr::UnixAddr; + +#[cfg(test)] +mod tests { + use std::mem::size_of; + + use super::*; + + #[test] + fn ipv4_to_and_from_c() { + let addr = [127u8, 0, 0, 1]; + let port = 8888u16; + + let c_addr = libc::sockaddr_in { + sin_family: libc::AF_INET as _, + sin_port: port.to_be(), + sin_addr: libc::in_addr { + s_addr: u32::from_be_bytes(addr).to_be(), + }, + sin_zero: [0u8; 8], + }; + + let addr = { + let addr = Ipv4Addr::new(addr[0], addr[1], addr[2], addr[3]); + Ipv4SocketAddr::new(addr, port) + }; + + check_to_and_from_c(&c_addr, &addr); + } + + fn check_to_and_from_c(c_addr: &T, addr: &U) { + let c_addr_storage = c_addr.to_c_storage(); + + // To C + assert!(are_sock_addrs_equal(c_addr, &addr.to_c_storage())); + assert!(are_sock_addrs_equal(&c_addr_storage, &addr.to_c_storage())); + + // From C + let (c_addr_storage, c_addr_len) = c_addr_storage; + assert!(&U::from_c_storage(&c_addr_storage, c_addr_len).unwrap() == addr); + } + + fn are_sock_addrs_equal(one: &T, other: &U) -> bool { + one.c_family() == other.c_family() && one.c_addr() == other.c_addr() + } +} diff --git a/src/libos/src/net/socket/socket_address.rs b/src/libos/src/net/socket/util/addr/raw_addr.rs similarity index 84% rename from src/libos/src/net/socket/socket_address.rs rename to src/libos/src/net/socket/util/addr/raw_addr.rs index 13612898..3f355c14 100644 --- a/src/libos/src/net/socket/socket_address.rs +++ b/src/libos/src/net/socket/util/addr/raw_addr.rs @@ -2,25 +2,33 @@ use super::*; use std::*; #[derive(Copy, Clone)] -pub struct SockAddr { +pub struct RawAddr { storage: libc::sockaddr_storage, len: usize, } // TODO: add more fields -impl fmt::Debug for SockAddr { +impl fmt::Debug for RawAddr { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("SockAddr") - .field( - "family", - &AddressFamily::try_from(self.storage.ss_family).unwrap(), - ) + f.debug_struct("RawAddr") + .field("family", &Domain::try_from(self.storage.ss_family).unwrap()) .field("len", &self.len) .finish() } } -impl SockAddr { +impl RawAddr { + pub fn from_c_storage(c_addr: &libc::sockaddr_storage, c_addr_len: usize) -> Self { + Self { + storage: *c_addr, + len: c_addr_len, + } + } + + pub fn to_c_storage(&self) -> (libc::sockaddr_storage, usize) { + (self.storage, self.len) + } + // Caller should guarentee the sockaddr and addr_len are valid pub unsafe fn try_from_raw( sockaddr: *const libc::sockaddr, @@ -34,13 +42,13 @@ impl SockAddr { return_errno!(EINVAL, "the address is too long."); } - match AddressFamily::try_from((*sockaddr).sa_family)? { - AddressFamily::INET => { + match Domain::try_from((*sockaddr).sa_family)? { + Domain::INET => { if addr_len < std::mem::size_of::() as u32 { return_errno!(EINVAL, "short ipv4 address."); } } - AddressFamily::INET6 => { + Domain::INET6 => { let ipv6_addr_len = std::mem::size_of::() as u32; // Omit sin6_scope_id when it is not fully provided // 4 represents the size of sin6_scope_id which is not a must @@ -110,7 +118,7 @@ impl SockAddr { } } -impl Default for SockAddr { +impl Default for RawAddr { fn default() -> Self { let mut storage: libc::sockaddr_storage = unsafe { mem::zeroed() }; Self { diff --git a/src/libos/src/net/socket/util/addr/unix_addr.rs b/src/libos/src/net/socket/util/addr/unix_addr.rs new file mode 100644 index 00000000..4a4ac271 --- /dev/null +++ b/src/libos/src/net/socket/util/addr/unix_addr.rs @@ -0,0 +1,212 @@ +use crate::net::socket::CSockAddr; + +use super::*; +use sgx_trts::c_str::CStr; +use std::path::{Path, PathBuf}; +use std::{cmp, mem, slice, str}; + +const MAX_PATH_LEN: usize = 108; +const SUN_FAMILY_LEN: usize = mem::size_of::(); +lazy_static! { + static ref SUN_PATH_OFFSET: usize = memoffset::offset_of!(libc::sockaddr_un, sun_path); +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum UnixAddr { + Unnamed, + File(Option, UnixPath), // An optional inode number and path. Use inode if there is one. + Abstract(String), +} + +impl UnixAddr { + /// Construct a unix address from its C counterpart. + /// + /// The argument `c_len` specifies the length of the valid part in the given + /// C address. + pub fn from_c(c_addr: &libc::sockaddr_un, c_len: usize) -> Result { + // Sanity checks + if c_addr.sun_family != libc::AF_UNIX as libc::sa_family_t { + return_errno!(EINVAL, "an unix address is expected"); + } + if c_len < std::mem::size_of::() { + return_errno!(EINVAL, "the length of the address is too small"); + } else if c_len > std::mem::size_of::() { + return_errno!(EINVAL, "the length of the address is too large"); + } + + if c_len == std::mem::size_of::() { + return Ok(Self::Unnamed); + } + + let path_len = c_len - std::mem::size_of::(); + debug_assert!(path_len > 1); + if path_len == 1 { + // Both pathname and abstract addresses require a path_len greater than 1. + return_errno!(EINVAL, "the pathname must not be empty"); + } + + // A pathname address + if c_addr.sun_path[0] != 0 { + // More sanity check + let last_char = c_addr.sun_path[path_len - 1]; + if last_char != 0 { + return_errno!(EINVAL, "the pathname is not null-terminated"); + } + + let pathname = { + // Safety. Converting from &[c_char] to &[i8] is harmless. + let path_slice: &[i8] = unsafe { + let char_slice = &c_addr.sun_path[..(path_len - 1)]; + std::mem::transmute(char_slice) + }; + let path_cstr = unsafe { CStr::from_ptr(path_slice.as_ptr()) }; + if path_cstr.to_bytes_with_nul().len() > MAX_PATH_LEN { + return_errno!(EINVAL, "no null in the address"); + } + path_cstr + .to_str() + .map_err(|_| errno!(EINVAL, "path is not UTF8"))? + .to_string() + }; + + Ok(Self::File(None, UnixPath::new(&pathname))) + } + // An abstract address + else { + // Safety. Converting from &[c_char] to &[u8] is harmless. + let u8_slice: &[u8] = unsafe { + let char_slice = &c_addr.sun_path[1..(path_len)]; + std::mem::transmute(char_slice) + }; + Ok(Self::Abstract( + str::from_utf8(u8_slice).unwrap().to_string(), + )) + } + } + + pub fn from_c_storage(c_addr: &libc::sockaddr_storage, c_addr_len: usize) -> Result { + if (c_addr_len) > std::mem::size_of::() { + return_errno!(EINVAL, "address length is too large"); + } + // Safety. Convert from sockaddr_storage to sockaddr_un is harmless. + let c_addr = unsafe { std::mem::transmute(c_addr) }; + unsafe { Self::from_c(c_addr, c_addr_len) } + } + + pub fn copy_to_slice(&self, dst: &mut [u8]) -> usize { + let (raw_addr, addr_len) = self.to_c(); + let src = + unsafe { std::slice::from_raw_parts(&raw_addr as *const _ as *const u8, addr_len) }; + let copied = std::cmp::min(dst.len(), addr_len); + dst[..copied].copy_from_slice(&src[..copied]); + copied + } + + pub fn raw_len(&self) -> usize { + /// The '/0' at the end of Self::File counts + match self.path_str() { + Ok(str) => str.len() + 1 + *SUN_PATH_OFFSET, + Err(_) => 0, + } + } + + pub fn path_str(&self) -> Result<&str> { + match self { + Self::File(_, unix_path) => Ok(&unix_path.path_str()), + Self::Abstract(path) => Ok(&path), + Self::Unnamed => return_errno!(EINVAL, "can't get path name for unnamed socket"), + } + } + + pub fn to_c_storage(&self) -> (libc::sockaddr_storage, usize) { + let c_un_addr = self.to_c(); + c_un_addr.to_c_storage() + } + + pub fn to_raw(&self) -> RawAddr { + let (storage, addr_len) = self.to_c_storage(); + RawAddr::from_c_storage(&storage, addr_len) + } + + fn to_c(&self) -> (libc::sockaddr_un, usize) { + const FAMILY_LEN: usize = std::mem::size_of::(); + + let mut addr: libc::sockaddr_un = unsafe { mem::zeroed() }; + addr.sun_family = Domain::LOCAL as libc::sa_family_t; + + let addr_len = match self { + Self::Unnamed => FAMILY_LEN, + Self::File(_, unix_path) => { + let path_str = unix_path.path_str(); + let buf_len = path_str.len(); + /// addr is initialized to all zeros and try_from_raw guarentees + /// unix_path length is shorter than sun_path, so sun_path here + /// will always have a null terminator + addr.sun_path[..buf_len] + .copy_from_slice(unsafe { &*(path_str.as_bytes() as *const _ as *const [i8]) }); + buf_len + *SUN_PATH_OFFSET + 1 + } + Self::Abstract(path_str) => { + addr.sun_path[0] = 0; + let buf_len = path_str.len() + 1; + addr.sun_path[1..buf_len] + .copy_from_slice(unsafe { &*(path_str.as_bytes() as *const _ as *const [i8]) }); + buf_len + *SUN_PATH_OFFSET + } + }; + + (addr, addr_len) + } +} + +impl Default for UnixAddr { + fn default() -> Self { + UnixAddr::Unnamed + } +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct UnixPath { + inner: PathBuf, + /// Holds the cwd when a relative path is created + cwd: Option, +} + +impl UnixPath { + pub fn new(path: &str) -> Self { + let inner = PathBuf::from(path); + let is_absolute = inner.is_absolute(); + Self { + inner: inner, + cwd: if is_absolute { + None + } else { + let thread = current!(); + let fs = thread.fs().read().unwrap(); + let cwd = fs.cwd().to_owned(); + + Some(cwd) + }, + } + } + + pub fn absolute(&self) -> String { + let path_str = self.path_str(); + if self.inner.is_absolute() { + path_str.to_string() + } else { + let mut prefix = self.cwd.as_ref().unwrap().clone(); + if prefix.ends_with("/") { + prefix.push_str(path_str); + } else { + prefix.push_str("/"); + prefix.push_str(path_str); + } + prefix + } + } + + pub fn path_str(&self) -> &str { + self.inner.to_str().unwrap() + } +} diff --git a/src/libos/src/net/socket/util/any_addr.rs b/src/libos/src/net/socket/util/any_addr.rs new file mode 100644 index 00000000..a8a0492f --- /dev/null +++ b/src/libos/src/net/socket/util/any_addr.rs @@ -0,0 +1,113 @@ +use std::mem::{self, MaybeUninit}; + +use crate::net::socket::Domain; +use crate::prelude::*; + +use super::{Addr, CSockAddr, Ipv4Addr, Ipv4SocketAddr, Ipv6SocketAddr, RawAddr, UnixAddr}; +use num_enum::IntoPrimitive; +use std::path::Path; + +#[derive(Clone, Debug)] +pub enum AnyAddr { + Ipv4(Ipv4SocketAddr), + Ipv6(Ipv6SocketAddr), + Unix(UnixAddr), + Raw(RawAddr), + Unspec, +} + +impl AnyAddr { + pub fn from_c_storage(c_addr: &libc::sockaddr_storage, c_addr_len: usize) -> Result { + let any_addr = match c_addr.ss_family as _ { + libc::AF_INET => { + let ipv4_addr = Ipv4SocketAddr::from_c_storage(c_addr, c_addr_len)?; + Self::Ipv4(ipv4_addr) + } + libc::AF_INET6 => { + let ipv6_addr = Ipv6SocketAddr::from_c_storage(c_addr, c_addr_len)?; + Self::Ipv6(ipv6_addr) + } + libc::AF_UNSPEC => Self::Unspec, + libc::AF_UNIX | libc::AF_LOCAL => { + let unix_addr = UnixAddr::from_c_storage(c_addr, c_addr_len)?; + Self::Unix(unix_addr) + } + _ => { + let raw_addr = RawAddr::from_c_storage(c_addr, c_addr_len); + Self::Raw(raw_addr) + } + }; + Ok(any_addr) + } + + pub fn to_c_storage(&self) -> (libc::sockaddr_storage, usize) { + match self { + Self::Ipv4(ipv4_addr) => ipv4_addr.to_c_storage(), + Self::Ipv6(ipv6_addr) => ipv6_addr.to_c_storage(), + Self::Unix(unix_addr) => unix_addr.to_c_storage(), + Self::Raw(raw_addr) => raw_addr.to_c_storage(), + Self::Unspec => { + let mut sockaddr_storage = + unsafe { MaybeUninit::::uninit().assume_init() }; + sockaddr_storage.ss_family = libc::AF_UNSPEC as _; + (sockaddr_storage, mem::size_of::()) + } + } + } + + pub fn to_raw(&self) -> RawAddr { + match self { + Self::Ipv4(ipv4_addr) => ipv4_addr.to_raw(), + Self::Ipv6(ipv6_addr) => ipv6_addr.to_raw(), + Self::Unix(unix_addr) => unix_addr.to_raw(), + Self::Raw(raw_addr) => *raw_addr, + Self::Unspec => { + let mut sockaddr_storage = + unsafe { MaybeUninit::::uninit().assume_init() }; + sockaddr_storage.ss_family = libc::AF_UNSPEC as _; + RawAddr::from_c_storage(&sockaddr_storage, mem::size_of::()) + } + } + } + + pub fn to_unix(&self) -> Result<&UnixAddr> { + match self { + Self::Unix(unix_addr) => Ok(unix_addr), + _ => return_errno!(EAFNOSUPPORT, "not unix address"), + } + } + + pub fn as_ipv4(&self) -> Option<&Ipv4SocketAddr> { + match self { + Self::Ipv4(ipv4_addr) => Some(ipv4_addr), + _ => None, + } + } + + pub fn to_ipv4(&self) -> Result<&Ipv4SocketAddr> { + match self { + Self::Ipv4(ipv4_addr) => Ok(ipv4_addr), + _ => return_errno!(EAFNOSUPPORT, "not ipv4 address"), + } + } + + pub fn to_ipv6(&self) -> Result<&Ipv6SocketAddr> { + match self { + Self::Ipv6(ipv6_addr) => Ok(ipv6_addr), + _ => return_errno!(EAFNOSUPPORT, "not ipv6 address"), + } + } + + pub fn is_unspec(&self) -> bool { + match self { + Self::Unspec => true, + _ => false, + } + } + + pub fn as_slice(&self) -> &[u8] { + let (storage, len) = self.to_c_storage(); + let addr = &storage as *const _ as *const _; + unsafe { std::slice::from_raw_parts(addr as *const u8, len) } + } +} diff --git a/src/libos/src/net/socket/address_family.rs b/src/libos/src/net/socket/util/domain.rs similarity index 88% rename from src/libos/src/net/socket/address_family.rs rename to src/libos/src/net/socket/util/domain.rs index 06a39a12..d80655b3 100644 --- a/src/libos/src/net/socket/address_family.rs +++ b/src/libos/src/net/socket/util/domain.rs @@ -1,10 +1,11 @@ use super::*; +use num_enum::{IntoPrimitive, TryFromPrimitive}; // The protocol family generally is the same as the address family -#[derive(Copy, Clone, Debug, PartialEq, Eq)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, IntoPrimitive, TryFromPrimitive)] #[repr(u16)] #[allow(non_camel_case_types)] -pub enum AddressFamily { +pub enum Domain { UNSPEC = 0, LOCAL = 1, /* Hide the families with the same number @@ -60,7 +61,7 @@ pub enum AddressFamily { MAX = 45, } -impl AddressFamily { +impl Domain { pub fn try_from(af: u16) -> Result { if af >= Self::MAX as u16 { return_errno!(EINVAL, "Unknown address family"); diff --git a/src/libos/src/net/socket/util/flags.rs b/src/libos/src/net/socket/util/flags.rs new file mode 100644 index 00000000..6ba10e0b --- /dev/null +++ b/src/libos/src/net/socket/util/flags.rs @@ -0,0 +1,39 @@ +use bitflags::bitflags; + +// Flags to use when sending data through a socket +bitflags! { + pub struct SendFlags: i32 { + const MSG_OOB = 0x01; // Sends out-of-band data on sockets + const MSG_DONTROUTE = 0x04; // Don't use a gateway to send out the packet + const MSG_DONTWAIT = 0x40; // Nonblocking io + const MSG_EOR = 0x80; // End of record + const MSG_CONFIRM = 0x0800; // Confirm path validity + const MSG_NOSIGNAL = 0x4000; // Do not generate SIGPIPE + const MSG_MORE = 0x8000; // Sender will send more + } +} + +// Flags to use when receiving data through a socket +bitflags! { + pub struct RecvFlags: i32 { + const MSG_OOB = 0x01; // Recv out-of-band data + const MSG_PEEK = 0x02; // Return data without removing that + const MSG_TRUNC = 0x20; // Return the real length of the packet or datagram + const MSG_DONTWAIT = 0x40; // Nonblocking io + const MSG_WAITALL = 0x0100; // Wait for a full request + const MSG_ERRQUEUE = 0x2000; // Fetch message from error queue + // recvmsg only + const MSG_CMSG_CLOEXEC = 0x40000000; // Set close_on_exec for file descriptor received through SCM_RIGHTS + } +} + +bitflags! { + pub struct MsgFlags: i32 { + const MSG_OOB = 0x01; // Expedited or out-of-band data was received + const MSG_CTRUNC = 0x08; // Some control data was discarded + const MSG_TRUNC = 0x20; // The trailing portion of a datagram was discarded + const MSG_EOR = 0x80; // End of record + const MSG_ERRQUEUE = 0x2000; // Fetch message from error queue + const MSG_NOTIFICATION = 0x8000; // Only applicable to SCTP socket + } +} diff --git a/src/libos/src/net/socket/iovs.rs b/src/libos/src/net/socket/util/iovs.rs similarity index 100% rename from src/libos/src/net/socket/iovs.rs rename to src/libos/src/net/socket/util/iovs.rs diff --git a/src/libos/src/net/socket/util/mod.rs b/src/libos/src/net/socket/util/mod.rs new file mode 100644 index 00000000..31f6dcf8 --- /dev/null +++ b/src/libos/src/net/socket/util/mod.rs @@ -0,0 +1,27 @@ +use crate::prelude::*; +use crate::untrusted::{ + SliceAsMutPtrAndLen, SliceAsPtrAndLen, UntrustedSlice, UntrustedSliceAlloc, +}; +use std; + +mod addr; +mod any_addr; +mod domain; +mod flags; +mod iovs; +mod msg; +mod protocol; +mod shutdown; +mod r#type; + +pub use self::addr::{ + Addr, CSockAddr, Ipv4Addr, Ipv4SocketAddr, Ipv6SocketAddr, RawAddr, UnixAddr, +}; +pub use self::any_addr::AnyAddr; +pub use self::domain::Domain; +pub use self::flags::{MsgFlags, RecvFlags, SendFlags}; +pub use self::iovs::{Iovs, IovsMut, SliceAsLibcIovec}; +pub use self::msg::{CMessages, CmsgData}; +pub use self::protocol::SocketProtocol; +pub use self::r#type::Type; +pub use self::shutdown::Shutdown; diff --git a/src/libos/src/net/socket/util/msg.rs b/src/libos/src/net/socket/util/msg.rs new file mode 100644 index 00000000..90b10751 --- /dev/null +++ b/src/libos/src/net/socket/util/msg.rs @@ -0,0 +1,115 @@ +/// Socket message and its flags. +use super::*; + +/// This struct is used to iterate through the control messages. +/// +/// `cmsghdr` is a C struct for ancillary data object information of a unix socket. +pub struct CMessages<'a> { + buffer: &'a [u8], + current: Option<&'a libc::cmsghdr>, +} + +impl<'a> Iterator for CMessages<'a> { + type Item = CmsgData<'a>; + + fn next(&mut self) -> Option { + let cmsg = unsafe { + let mut msg: libc::msghdr = core::mem::zeroed(); + msg.msg_control = self.buffer.as_ptr() as *mut _; + msg.msg_controllen = self.buffer.len() as _; + + let cmsg = if let Some(current) = self.current { + libc::CMSG_NXTHDR(&msg, current) + } else { + libc::CMSG_FIRSTHDR(&msg) + }; + cmsg.as_ref()? + }; + + self.current = Some(cmsg); + CmsgData::try_from_cmsghdr(cmsg) + } +} + +impl<'a> CMessages<'a> { + pub fn from_bytes(msg_control: &'a mut [u8]) -> Self { + Self { + buffer: msg_control, + current: None, + } + } +} + +/// Control message data of variable type. The data resides next to `cmsghdr`. +pub enum CmsgData<'a> { + ScmRights(ScmRights<'a>), + ScmCredentials, +} + +impl<'a> CmsgData<'a> { + /// Create an `CmsgData::ScmRights` variant. + /// + /// # Safety + /// + /// `data` must contain a valid control message and the control message must be type of + /// `SOL_SOCKET` and level of `SCM_RIGHTS`. + unsafe fn as_rights(data: &'a mut [u8]) -> Self { + let scm_rights = ScmRights { data }; + CmsgData::ScmRights(scm_rights) + } + + /// Create an `CmsgData::ScmCredentials` variant. + /// + /// # Safety + /// + /// `data` must contain a valid control message and the control message must be type of + /// `SOL_SOCKET` and level of `SCM_CREDENTIALS`. + unsafe fn as_credentials(_data: &'a [u8]) -> Self { + CmsgData::ScmCredentials + } + + fn try_from_cmsghdr(cmsg: &'a libc::cmsghdr) -> Option { + unsafe { + let cmsg_len_zero = libc::CMSG_LEN(0) as usize; + let data_len = (*cmsg).cmsg_len as usize - cmsg_len_zero; + let data = libc::CMSG_DATA(cmsg); + let data = core::slice::from_raw_parts_mut(data, data_len); + + match (*cmsg).cmsg_level { + libc::SOL_SOCKET => match (*cmsg).cmsg_type { + libc::SCM_RIGHTS => Some(CmsgData::as_rights(data)), + libc::SCM_CREDENTIALS => Some(CmsgData::as_credentials(data)), + _ => None, + }, + _ => None, + } + } + } +} + +/// The data unit of this control message is file descriptor(s). +/// +/// The level is equal to `SOL_SOCKET` and the type is equal to `SCM_RIGHTS`. +pub struct ScmRights<'a> { + data: &'a mut [u8], +} + +impl<'a> ScmRights<'a> { + /// Iterate and reassign each fd in data buffer, given a reassignment function. + pub fn iter_and_reassign_fds(&mut self, reassign_fd_fn: F) + where + F: Fn(FileDesc) -> FileDesc, + { + for fd_bytes in self.data.chunks_exact_mut(core::mem::size_of::()) { + let old_fd = FileDesc::from_ne_bytes(fd_bytes.try_into().unwrap()); + let reassigned_fd = reassign_fd_fn(old_fd); + fd_bytes.copy_from_slice(&reassigned_fd.to_ne_bytes()); + } + } + + pub fn iter_fds(&self) -> impl Iterator + '_ { + self.data + .chunks_exact(core::mem::size_of::()) + .map(|fd_bytes| FileDesc::from_ne_bytes(fd_bytes.try_into().unwrap())) + } +} diff --git a/src/libos/src/net/socket/util/protocol.rs b/src/libos/src/net/socket/util/protocol.rs new file mode 100644 index 00000000..00ff1a7c --- /dev/null +++ b/src/libos/src/net/socket/util/protocol.rs @@ -0,0 +1,34 @@ +use num_enum::{IntoPrimitive, TryFromPrimitive}; + +/* Standard well-defined IP protocols. */ +#[allow(non_camel_case_types)] +#[derive(Clone, Copy, Debug, Eq, PartialEq, IntoPrimitive, TryFromPrimitive)] +#[repr(i32)] +pub enum SocketProtocol { + IPPROTO_IP = 0, /* Dummy protocol for TCP. */ + IPPROTO_ICMP = 1, /* Internet Control Message Protocol. */ + IPPROTO_IGMP = 2, /* Internet Group Management Protocol. */ + IPPROTO_IPIP = 4, /* IPIP tunnels (older KA9Q tunnels use 94). */ + IPPROTO_TCP = 6, /* Transmission Control Protocol. */ + IPPROTO_EGP = 8, /* Exterior Gateway Protocol. */ + IPPROTO_PUP = 12, /* PUP protocol. */ + IPPROTO_UDP = 17, /* User Datagram Protocol. */ + IPPROTO_IDP = 22, /* XNS IDP protocol. */ + IPPROTO_TP = 29, /* SO Transport Protocol Class 4. */ + IPPROTO_DCCP = 33, /* Datagram Congestion Control Protocol. */ + IPPROTO_IPV6 = 41, /* IPv6 header. */ + IPPROTO_RSVP = 46, /* Reservation Protocol. */ + IPPROTO_GRE = 47, /* General Routing Encapsulation. */ + IPPROTO_ESP = 50, /* encapsulating security payload. */ + IPPROTO_AH = 51, /* authentication header. */ + IPPROTO_MTP = 92, /* Multicast Transport Protocol. */ + IPPROTO_BEETPH = 94, /* IP option pseudo header for BEET. */ + IPPROTO_ENCAP = 98, /* Encapsulation Header. */ + IPPROTO_PIM = 103, /* Protocol Independent Multicast. */ + IPPROTO_COMP = 108, /* Compression Header Protocol. */ + IPPROTO_SCTP = 132, /* Stream Control Transmission Protocol. */ + IPPROTO_UDPLITE = 136, /* UDP-Lite protocol. */ + IPPROTO_MPLS = 137, /* MPLS in IP. */ + IPPROTO_RAW = 255, /* Raw IP packets. */ + IPPROTO_MAX, +} diff --git a/src/libos/src/net/socket/util/shutdown.rs b/src/libos/src/net/socket/util/shutdown.rs new file mode 100644 index 00000000..e0166f7f --- /dev/null +++ b/src/libos/src/net/socket/util/shutdown.rs @@ -0,0 +1,34 @@ +use crate::prelude::*; + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[repr(u32)] +pub enum Shutdown { + Read = 0, + Write = 1, + Both = 2, +} + +impl Shutdown { + pub fn from_c(c_val: u32) -> Result { + match c_val { + 0 => Ok(Self::Read), + 1 => Ok(Self::Write), + 2 => Ok(Self::Both), + _ => return_errno!(EINVAL, "invalid how"), + } + } + + pub fn to_c(&self) -> u32 { + *self as u32 + } + + pub fn should_shut_read(&self) -> bool { + // a slightly more efficient check than using two equality comparions + self.to_c() % 2 == 0 + } + + pub fn should_shut_write(&self) -> bool { + // a slightly more efficient check than using two equality comparions + self.to_c() >= 1 + } +} diff --git a/src/libos/src/net/socket/util/type.rs b/src/libos/src/net/socket/util/type.rs new file mode 100644 index 00000000..1c861d23 --- /dev/null +++ b/src/libos/src/net/socket/util/type.rs @@ -0,0 +1,15 @@ +use crate::prelude::*; +use num_enum::{IntoPrimitive, TryFromPrimitive}; + +/// A network type. +#[derive(Clone, Copy, Debug, Eq, PartialEq, IntoPrimitive, TryFromPrimitive)] +#[repr(i32)] +pub enum Type { + STREAM = libc::SOCK_STREAM, + DGRAM = libc::SOCK_DGRAM, + RAW = libc::SOCK_RAW, + RDM = libc::SOCK_RDM, + SEQPACKET = libc::SOCK_SEQPACKET, + DCCP = libc::SOCK_DCCP, + PACKET = libc::SOCK_PACKET, +}