From 070a024c0d5ebb16d3151db3902a73e80cb87cfd Mon Sep 17 00:00:00 2001 From: ClawSeven Date: Fri, 24 May 2024 17:18:50 +0800 Subject: [PATCH] [Libos] Refactor io_uring and related module implementions --- src/libos/src/events/waiter/edge.rs | 9 +- src/libos/src/events/waiter_queue.rs | 8 +- .../fs/file_ops/ioctl/builtin/get_ifconf.rs | 109 ++--- .../fs/file_ops/ioctl/builtin/get_ifreq.rs | 56 +-- .../src/fs/file_ops/ioctl/builtin/mod.rs | 1 - src/libos/src/fs/pipe.rs | 2 +- src/libos/src/io_uring.rs | 10 +- src/libos/src/net/mod.rs | 10 +- src/libos/src/net/socket/host/mod.rs | 23 +- src/libos/src/net/socket/host/recv.rs | 20 +- src/libos/src/net/socket/host/send.rs | 14 +- src/libos/src/net/socket/host/socket_file.rs | 2 +- src/libos/src/net/socket/mod.rs | 19 +- .../src/net/socket/sockopt/get_output.rs | 9 +- .../src/net/socket/sockopt/get_sockbuf.rs | 4 +- src/libos/src/net/socket/sockopt/mod.rs | 4 +- src/libos/src/net/socket/sockopt/set.rs | 12 +- .../src/net/socket/sockopt/set_sockbuf.rs | 40 +- src/libos/src/net/socket/sockopt/timeout.rs | 30 +- src/libos/src/net/socket/unix/mod.rs | 8 +- .../src/net/socket/unix/stream/stream.rs | 11 +- .../src/net/socket/uring/common/common.rs | 10 +- .../src/net/socket/uring/datagram/generic.rs | 8 +- src/libos/src/net/socket/uring/socket_file.rs | 16 +- src/libos/src/net/socket/uring/stream/mod.rs | 22 +- .../uring/stream/states/connected/recv.rs | 5 +- .../uring/stream/states/connected/send.rs | 9 +- .../net/socket/uring/stream/states/init.rs | 2 +- .../net/socket/uring/stream/states/listen.rs | 6 +- src/libos/src/net/socket/util/addr/ipv4.rs | 7 +- src/libos/src/net/socket/util/addr/ipv6.rs | 6 +- src/libos/src/net/socket/util/addr/mod.rs | 2 +- .../src/net/socket/util/addr/raw_addr.rs | 10 +- .../src/net/socket/util/addr/unix_addr.rs | 6 +- src/libos/src/net/socket/util/any_addr.rs | 17 +- src/libos/src/net/socket/util/flags.rs | 17 + src/libos/src/net/socket/util/mod.rs | 6 +- src/libos/src/net/socket/util/type.rs | 2 +- src/libos/src/net/syscalls.rs | 395 ++++++++---------- src/libos/src/prelude.rs | 3 +- src/libos/src/process/mod.rs | 1 - src/libos/src/util/sync/mutex.rs | 45 +- 42 files changed, 478 insertions(+), 518 deletions(-) diff --git a/src/libos/src/events/waiter/edge.rs b/src/libos/src/events/waiter/edge.rs index 41e94bbd..79c12dd1 100644 --- a/src/libos/src/events/waiter/edge.rs +++ b/src/libos/src/events/waiter/edge.rs @@ -26,7 +26,11 @@ impl Synchronizer for EdgeSync { return Ok(()); } loop { - self.host_eventfd.poll(timeout)?; + if let Err(error) = self.host_eventfd.poll(timeout) { + self.state.store(INIT, Ordering::Relaxed); + return Err(error); + } + if self .state .compare_exchange(NOTIFIED, INIT, Ordering::Acquire, Ordering::Acquire) @@ -47,8 +51,7 @@ impl Synchronizer for EdgeSync { // Need to change timeout from `Option<&mut Duration>` to `&mut Option` // so that the Rust compiler is happy about using the variable in a loop. let ret = self.host_eventfd.poll_mut(remain.as_mut()); - // Wait for something to happen, assuming it's still set to PARKED. - // futex_wait(&self.state, PARKED, Some(timeout)); + // Wait for something to happen, assuming it's still set to NOTIFIED. // This is not just a store, because we need to establish a // release-acquire ordering with unpark(). if self.state.swap(INIT, Ordering::Acquire) == NOTIFIED { diff --git a/src/libos/src/events/waiter_queue.rs b/src/libos/src/events/waiter_queue.rs index 8cb068db..1dc8f1f2 100644 --- a/src/libos/src/events/waiter_queue.rs +++ b/src/libos/src/events/waiter_queue.rs @@ -27,7 +27,7 @@ use crate::prelude::*; /// Although it is safer to use AcqRel,here using `Release` would be enough. pub struct WaiterQueue { count: AtomicUsize, - wakers: SgxMutex>>, + wakers: Mutex>>, } impl WaiterQueue { @@ -35,7 +35,7 @@ impl WaiterQueue { pub fn new() -> Self { Self { count: AtomicUsize::new(0), - wakers: SgxMutex::new(VecDeque::new()), + wakers: Mutex::new(VecDeque::new()), } } @@ -54,7 +54,7 @@ impl WaiterQueue { pub fn reset_and_enqueue(&self, waiter: &Waiter) { waiter.reset(); - let mut wakers = self.wakers.lock().unwrap(); + let mut wakers = self.wakers.lock(); self.count.fetch_add(1, Ordering::Release); wakers.push_back(waiter.waker()); } @@ -78,7 +78,7 @@ impl WaiterQueue { // Dequeue wakers let to_wake = { - let mut wakers = self.wakers.lock().unwrap(); + let mut wakers = self.wakers.lock(); let max_count = max_count.min(wakers.len()); let to_wake: Vec> = wakers.drain(..max_count).collect(); self.count.fetch_sub(to_wake.len(), Ordering::Release); diff --git a/src/libos/src/fs/file_ops/ioctl/builtin/get_ifconf.rs b/src/libos/src/fs/file_ops/ioctl/builtin/get_ifconf.rs index c507a248..d8ceaf81 100644 --- a/src/libos/src/fs/file_ops/ioctl/builtin/get_ifconf.rs +++ b/src/libos/src/fs/file_ops/ioctl/builtin/get_ifconf.rs @@ -36,7 +36,8 @@ impl GetIfConf { } let mut if_conf = self.to_raw_ifconf(); - get_ifconf_by_host(fd, &mut if_conf)?; + GetIfConf::get_ifconf_by_host(fd, &mut if_conf)?; + self.set_len(if_conf.ifc_len as usize); Ok(()) } @@ -78,59 +79,59 @@ impl GetIfConf { }, } } + + fn get_ifconf_by_host(fd: FileDesc, if_conf: &mut IfConf) -> Result<()> { + const SIOCGIFCONF: u32 = 0x8912; + + extern "C" { + // Used to ioctl arguments with pointer members. + // + // Before the call the area the pointers points to should be assembled into + // one continuous memory block. Then the block is repacked to ioctl arguments + // in the ocall implementation in host. + // + // ret: holds the return value of ioctl in host + // fd: the host fd for the device + // cmd_num: request number of the ioctl + // buf: the data to exchange with host + // len: the size of the buf + // recv_len: accepts transferred data length when buf is used to get data from host + // + fn occlum_ocall_ioctl_repack( + ret: *mut i32, + fd: i32, + cmd_num: i32, + buf: *const u8, + len: i32, + recv_len: *mut i32, + ) -> sgx_types::sgx_status_t; + } + + try_libc!({ + let mut recv_len: i32 = 0; + let mut retval: i32 = 0; + let status = occlum_ocall_ioctl_repack( + &mut retval as *mut i32, + fd as _, + SIOCGIFCONF as _, + if_conf.ifc_buf, + if_conf.ifc_len, + &mut recv_len as *mut i32, + ); + assert!(status == sgx_types::sgx_status_t::SGX_SUCCESS); + + // If ifc_req is NULL, SIOCGIFCONF returns the necessary buffer + // size in bytes for receiving all available addresses in ifc_len + // which is irrelevant to the orginal ifc_len. + if !if_conf.ifc_buf.is_null() { + assert!(if_conf.ifc_len >= recv_len); + } + if_conf.ifc_len = recv_len; + retval + }); + + Ok(()) + } } impl IoctlCmd for GetIfConf {} - -const SIOCGIFCONF: u32 = 0x8912; - -fn get_ifconf_by_host(fd: FileDesc, if_conf: &mut IfConf) -> Result<()> { - extern "C" { - // Used to ioctl arguments with pointer members. - // - // Before the call the area the pointers points to should be assembled into - // one continuous memory block. Then the block is repacked to ioctl arguments - // in the ocall implementation in host. - // - // ret: holds the return value of ioctl in host - // fd: the host fd for the device - // cmd_num: request number of the ioctl - // buf: the data to exchange with host - // len: the size of the buf - // recv_len: accepts transferred data length when buf is used to get data from host - // - fn occlum_ocall_ioctl_repack( - ret: *mut i32, - fd: i32, - cmd_num: i32, - buf: *const u8, - len: i32, - recv_len: *mut i32, - ) -> sgx_types::sgx_status_t; - } - - try_libc!({ - let mut recv_len: i32 = 0; - let mut retval: i32 = 0; - let status = occlum_ocall_ioctl_repack( - &mut retval as *mut i32, - fd as _, - SIOCGIFCONF as _, - if_conf.ifc_buf, - if_conf.ifc_len, - &mut recv_len as *mut i32, - ); - assert!(status == sgx_types::sgx_status_t::SGX_SUCCESS); - - // If ifc_req is NULL, SIOCGIFCONF returns the necessary buffer - // size in bytes for receiving all available addresses in ifc_len - // which is irrelevant to the orginal ifc_len. - if !if_conf.ifc_buf.is_null() { - assert!(if_conf.ifc_len >= recv_len); - } - if_conf.ifc_len = recv_len; - retval - }); - - Ok(()) -} diff --git a/src/libos/src/fs/file_ops/ioctl/builtin/get_ifreq.rs b/src/libos/src/fs/file_ops/ioctl/builtin/get_ifreq.rs index cb9b7d91..772d9979 100644 --- a/src/libos/src/fs/file_ops/ioctl/builtin/get_ifreq.rs +++ b/src/libos/src/fs/file_ops/ioctl/builtin/get_ifreq.rs @@ -34,40 +34,40 @@ impl GetIfReqWithRawCmd { pub fn execute(&mut self, fd: FileDesc) -> Result<()> { let input_if_req = self.inner.input(); - let output_if_req = get_ifreq_by_host(fd, self.raw_cmd, input_if_req)?; + let output_if_req = GetIfReqWithRawCmd::get_ifreq_by_host(fd, self.raw_cmd, input_if_req)?; self.inner.set_output(output_if_req); Ok(()) } -} -fn get_ifreq_by_host(fd: FileDesc, cmd: u32, req: &IfReq) -> Result { - let mut if_req: IfReq = req.clone(); - try_libc!({ - let mut retval: i32 = 0; - extern "C" { - pub fn occlum_ocall_ioctl( - ret: *mut i32, - fd: c_int, - request: c_int, - arg: *mut c_void, - len: size_t, - ) -> sgx_types::sgx_status_t; - } + fn get_ifreq_by_host(fd: FileDesc, cmd: u32, req: &IfReq) -> Result { + let mut if_req: IfReq = req.clone(); + try_libc!({ + let mut retval: i32 = 0; + extern "C" { + pub fn occlum_ocall_ioctl( + ret: *mut i32, + fd: c_int, + request: c_int, + arg: *mut c_void, + len: size_t, + ) -> sgx_types::sgx_status_t; + } - use libc::{c_int, c_void, size_t}; - use occlum_ocall_ioctl as do_ioctl; + use libc::{c_int, c_void, size_t}; + use occlum_ocall_ioctl as do_ioctl; - let status = do_ioctl( - &mut retval as *mut i32, - fd as i32, - cmd as i32, - &mut if_req as *mut IfReq as *mut c_void, - std::mem::size_of::(), - ); - assert!(status == sgx_types::sgx_status_t::SGX_SUCCESS); - retval - }); - Ok(if_req) + let status = do_ioctl( + &mut retval as *mut i32, + fd as i32, + cmd as i32, + &mut if_req as *mut IfReq as *mut c_void, + std::mem::size_of::(), + ); + assert!(status == sgx_types::sgx_status_t::SGX_SUCCESS); + retval + }); + Ok(if_req) + } } impl IoctlCmd for GetIfReqWithRawCmd {} diff --git a/src/libos/src/fs/file_ops/ioctl/builtin/mod.rs b/src/libos/src/fs/file_ops/ioctl/builtin/mod.rs index 3ac75d57..efe8a054 100644 --- a/src/libos/src/fs/file_ops/ioctl/builtin/mod.rs +++ b/src/libos/src/fs/file_ops/ioctl/builtin/mod.rs @@ -214,7 +214,6 @@ pub use self::set_close_on_exec::*; pub use self::set_nonblocking::SetNonBlocking; pub use self::termios::*; pub use self::winsize::*; -pub use net::socket::sockopt::SetSockOptRawCmd; mod get_ifconf; mod get_ifreq; diff --git a/src/libos/src/fs/pipe.rs b/src/libos/src/fs/pipe.rs index 6e768a36..3d38fa97 100644 --- a/src/libos/src/fs/pipe.rs +++ b/src/libos/src/fs/pipe.rs @@ -112,7 +112,7 @@ impl File for PipeReader { fn ioctl(&self, cmd: &mut dyn IoctlCmd) -> Result<()> { match_ioctl_cmd_auto_error!(cmd, { cmd : GetReadBufLen => { - let read_buf_len = self.consumer.ready_len(); + let read_buf_len = self.get_ready_len().min(std::i32::MAX as usize) as i32; cmd.set_output(read_buf_len as _); }, }); diff --git a/src/libos/src/io_uring.rs b/src/libos/src/io_uring.rs index e629ad32..351d8112 100644 --- a/src/libos/src/io_uring.rs +++ b/src/libos/src/io_uring.rs @@ -104,13 +104,11 @@ impl UringSet { // Sum registered socket let total_socket_num = map .values() - .fold(0, |acc, state| acc + state.registered_num) - + 1; + .fold(0, |acc, state| acc + state.registered_num); // Determine the number of available io_uring let uring_num = (total_socket_num / SOCKET_THRESHOLD_PER_URING) + 1; - let existed_uring_num = self.running_uring_num.load(Ordering::Relaxed); - assert!(existed_uring_num <= uring_num); - existed_uring_num < uring_num + + running_uring_num < uring_num }; if should_build_uring { @@ -134,7 +132,7 @@ impl UringSet { // Link the file to the io_uring instance with the least load. let (mut uring, mut state) = map .iter_mut() - .min_by_key(|(_, &mut state)| state.registered_num) + .min_by_key(|(_, state)| state.registered_num) .unwrap(); // Re-select io_uring instance with least task load diff --git a/src/libos/src/net/mod.rs b/src/libos/src/net/mod.rs index 04bb2172..942cab91 100644 --- a/src/libos/src/net/mod.rs +++ b/src/libos/src/net/mod.rs @@ -7,13 +7,17 @@ pub use self::io_multiplexing::{ PollEventFlags, PollFd, THREAD_NOTIFIERS, }; pub use self::socket::{ - socketpair, unix_socket, AsUnixSocket, Domain, HostSocket, HostSocketType, Iovs, IovsMut, - RawAddr, SliceAsLibcIovec, UnixAddr, + mmsghdr, socketpair, unix_socket, Addr, AnyAddr, AsUnixSocket, Domain, GetAcceptConnCmd, + GetDomainCmd, GetErrorCmd, GetOutputAsBytes, GetPeerNameCmd, GetRecvBufSizeCmd, + GetRecvTimeoutCmd, GetSendBufSizeCmd, GetSendTimeoutCmd, GetSockOptRawCmd, GetTypeCmd, + HostSocket, HostSocketType, Iovs, IovsMut, RecvFlags, SendFlags, SetRecvBufSizeCmd, + SetRecvTimeoutCmd, SetSendBufSizeCmd, SetSendTimeoutCmd, SetSockOptRawCmd, Shutdown, + SliceAsLibcIovec, SockAddr, SockOptName, SocketFile, SocketType, UnixAddr, UringSocketType, }; pub use self::syscalls::*; mod io_multiplexing; -pub(crate) mod socket; +mod socket; mod syscalls; pub use self::syscalls::*; diff --git a/src/libos/src/net/socket/host/mod.rs b/src/libos/src/net/socket/host/mod.rs index d5230ee3..b142f7f0 100644 --- a/src/libos/src/net/socket/host/mod.rs +++ b/src/libos/src/net/socket/host/mod.rs @@ -26,7 +26,7 @@ pub struct HostSocket { impl HostSocket { pub fn new( domain: Domain, - socket_type: Type, + socket_type: SocketType, socket_flags: SocketFlags, protocol: i32, ) -> Result { @@ -49,7 +49,7 @@ impl HostSocket { }) } - pub fn bind(&self, addr: &RawAddr) -> Result<()> { + pub fn bind(&self, addr: &SockAddr) -> Result<()> { let (addr_ptr, addr_len) = addr.as_ptr_and_len(); let ret = try_libc!(libc::ocall::bind( @@ -65,8 +65,8 @@ impl HostSocket { Ok(()) } - pub fn accept(&self, flags: SocketFlags) -> Result<(Self, Option)> { - let mut sockaddr = RawAddr::default(); + pub fn accept(&self, flags: SocketFlags) -> Result<(Self, Option)> { + let mut sockaddr = SockAddr::default(); let mut addr_len = sockaddr.len(); let raw_host_fd = try_libc!(libc::ocall::accept4( @@ -86,8 +86,8 @@ impl HostSocket { Ok((HostSocket::from_host_fd(host_fd)?, addr_option)) } - pub fn addr(&self) -> Result { - let mut sockaddr = RawAddr::default(); + pub fn addr(&self) -> Result { + let mut sockaddr = SockAddr::default(); let mut addr_len = sockaddr.len(); try_libc!(libc::ocall::getsockname( self.raw_host_fd() as i32, @@ -99,8 +99,8 @@ impl HostSocket { Ok(sockaddr) } - pub fn peer_addr(&self) -> Result { - let mut sockaddr = RawAddr::default(); + pub fn peer_addr(&self) -> Result { + let mut sockaddr = SockAddr::default(); let mut addr_len = sockaddr.len(); try_libc!(libc::ocall::getpeername( self.raw_host_fd() as i32, @@ -112,7 +112,7 @@ impl HostSocket { Ok(sockaddr) } - pub fn connect(&self, addr: &Option) -> Result<()> { + pub fn connect(&self, addr: Option<&SockAddr>) -> Result<()> { debug!("connect: host_fd: {}, addr {:?}", self.raw_host_fd(), addr); let (addr_ptr, addr_len) = if let Some(sock_addr) = addr { @@ -133,14 +133,13 @@ impl HostSocket { &self, buf: &[u8], flags: SendFlags, - addr_option: &Option, + addr_option: Option, ) -> Result { let bufs = vec![buf]; self.sendmsg(&bufs, flags, addr_option, None) } - pub fn recvfrom(&self, buf: &mut [u8], flags: RecvFlags) -> Result<(usize, Option)> { - let mut sockaddr = RawAddr::default(); + pub fn recvfrom(&self, buf: &mut [u8], flags: RecvFlags) -> Result<(usize, Option)> { let mut bufs = vec![buf]; let (bytes_recv, recv_addr, _, _) = self.recvmsg(&mut bufs, flags, None)?; diff --git a/src/libos/src/net/socket/host/recv.rs b/src/libos/src/net/socket/host/recv.rs index edacdb37..1b8b10c4 100644 --- a/src/libos/src/net/socket/host/recv.rs +++ b/src/libos/src/net/socket/host/recv.rs @@ -12,7 +12,7 @@ impl HostSocket { data: &mut [&mut [u8]], flags: RecvFlags, control: Option<&mut [u8]>, - ) -> Result<(usize, Option, MsgFlags, usize)> { + ) -> Result<(usize, Option, MsgFlags, usize)> { let current = current!(); let data_length = data.iter().map(|s| s.len()).sum(); let mut ocall_alloc; @@ -54,10 +54,10 @@ impl HostSocket { data: &mut [UntrustedSlice], flags: RecvFlags, mut control: Option<&mut [u8]>, - ) -> Result<(usize, Option, MsgFlags, usize)> { + ) -> Result<(usize, Option, MsgFlags, usize)> { // Prepare the arguments for OCall let host_fd = self.raw_host_fd() as i32; - let mut addr = RawAddr::default(); + let mut addr = SockAddr::default(); let mut msg_name = addr.as_mut_ptr(); let mut msg_namelen = addr.len(); let mut msg_namelen_recvd = 0_u32; @@ -122,16 +122,16 @@ impl HostSocket { }; let msg_namelen_recvd = msg_namelen_recvd as usize; - let raw_addr = if msg_namelen_recvd == 0 { - None - } else { - addr.set_len(msg_namelen_recvd)?; - Some(addr) - }; + let raw_addr = (msg_namelen_recvd != 0).then(|| { + addr.set_len(msg_namelen_recvd); + addr + }); + + let addr = raw_addr.map(|addr| AnyAddr::Raw(addr)); assert!(msg_namelen_recvd <= msg_namelen); assert!(msg_controllen_recvd <= msg_controllen); - Ok((bytes_recvd, raw_addr, flags_recvd, msg_controllen_recvd)) + Ok((bytes_recvd, addr, flags_recvd, msg_controllen_recvd)) } } diff --git a/src/libos/src/net/socket/host/send.rs b/src/libos/src/net/socket/host/send.rs index 0df6f73c..df1af688 100644 --- a/src/libos/src/net/socket/host/send.rs +++ b/src/libos/src/net/socket/host/send.rs @@ -2,14 +2,14 @@ use super::*; impl HostSocket { pub fn send(&self, buf: &[u8], flags: SendFlags) -> Result { - self.sendto(buf, flags, &None) + self.sendto(buf, flags, None) } pub fn sendmsg( &self, data: &[&[u8]], flags: SendFlags, - addr: &Option, + addr: Option, control: Option<&[u8]>, ) -> Result { let current = current!(); @@ -34,8 +34,14 @@ impl HostSocket { bufs }; - let name = addr.as_ref().map(|raw_addr| raw_addr.as_slice()); - self.do_sendmsg_untrusted_data(&u_data, flags, name, control) + let raw_addr = addr.map(|addr| addr.to_raw()); + + self.do_sendmsg_untrusted_data( + &u_data, + flags, + raw_addr.as_ref().map(|addr| addr.as_slice()), + control, + ) } fn do_sendmsg_untrusted_data( diff --git a/src/libos/src/net/socket/host/socket_file.rs b/src/libos/src/net/socket/host/socket_file.rs index 3fe9598b..700ed3d5 100644 --- a/src/libos/src/net/socket/host/socket_file.rs +++ b/src/libos/src/net/socket/host/socket_file.rs @@ -41,7 +41,7 @@ impl File for HostSocket { } fn writev(&self, bufs: &[&[u8]]) -> Result { - self.sendmsg(bufs, SendFlags::empty(), &None, None) + self.sendmsg(bufs, SendFlags::empty(), None, None) } fn seek(&self, pos: SeekFrom) -> Result { diff --git a/src/libos/src/net/socket/mod.rs b/src/libos/src/net/socket/mod.rs index 37c72801..7270c7e6 100644 --- a/src/libos/src/net/socket/mod.rs +++ b/src/libos/src/net/socket/mod.rs @@ -1,15 +1,22 @@ use super::*; mod host; -pub(crate) mod sockopt; +mod sockopt; mod unix; -pub(crate) mod uring; -pub(crate) mod util; +mod uring; +mod util; pub use self::host::{HostSocket, HostSocketType}; pub use self::unix::{socketpair, unix_socket, AsUnixSocket}; pub use self::util::{ - Addr, AnyAddr, CMessages, CSockAddr, CmsgData, Domain, Iovs, IovsMut, Ipv4Addr, Ipv4SocketAddr, - Ipv6SocketAddr, MsgFlags, RawAddr, RecvFlags, SendFlags, Shutdown, SliceAsLibcIovec, - SocketProtocol, Type, UnixAddr, + mmsghdr, Addr, AnyAddr, CMessages, CSockAddr, CmsgData, Domain, Iovs, IovsMut, Ipv4Addr, + Ipv4SocketAddr, Ipv6SocketAddr, MsgFlags, RecvFlags, SendFlags, Shutdown, SliceAsLibcIovec, + SockAddr, SocketFlags, SocketProtocol, SocketType, UnixAddr, }; +pub use sockopt::{ + GetAcceptConnCmd, GetDomainCmd, GetErrorCmd, GetOutputAsBytes, GetPeerNameCmd, + GetRecvBufSizeCmd, GetRecvTimeoutCmd, GetSendBufSizeCmd, GetSendTimeoutCmd, GetSockOptRawCmd, + GetTypeCmd, SetRecvBufSizeCmd, SetRecvTimeoutCmd, SetSendBufSizeCmd, SetSendTimeoutCmd, + SetSockOptRawCmd, SockOptName, +}; +pub use uring::{socket_file::SocketFile, UringSocketType}; diff --git a/src/libos/src/net/socket/sockopt/get_output.rs b/src/libos/src/net/socket/sockopt/get_output.rs index dab0bb53..01dd34bd 100644 --- a/src/libos/src/net/socket/sockopt/get_output.rs +++ b/src/libos/src/net/socket/sockopt/get_output.rs @@ -1,12 +1,11 @@ use super::{GetRecvTimeoutCmd, GetSendTimeoutCmd}; use super::{ - GetAcceptConnCmd, GetDomainCmd, GetErrorCmd, GetPeerNameCmd, GetRcvBufSizeCmd, - GetSndBufSizeCmd, GetSockOptRawCmd, GetTypeCmd, + GetAcceptConnCmd, GetDomainCmd, GetErrorCmd, GetPeerNameCmd, GetRecvBufSizeCmd, + GetSendBufSizeCmd, GetSockOptRawCmd, GetTypeCmd, }; use libc::timeval; -use std::time::Duration; use crate::prelude::*; @@ -60,7 +59,7 @@ impl GetOutputAsBytes for GetTypeCmd { } } -impl GetOutputAsBytes for GetSndBufSizeCmd { +impl GetOutputAsBytes for GetSendBufSizeCmd { fn get_output_as_bytes(&self) -> Option<&[u8]> { self.output().map(|val_ref| unsafe { std::slice::from_raw_parts( @@ -71,7 +70,7 @@ impl GetOutputAsBytes for GetSndBufSizeCmd { } } -impl GetOutputAsBytes for GetRcvBufSizeCmd { +impl GetOutputAsBytes for GetRecvBufSizeCmd { fn get_output_as_bytes(&self) -> Option<&[u8]> { self.output().map(|val_ref| unsafe { std::slice::from_raw_parts( diff --git a/src/libos/src/net/socket/sockopt/get_sockbuf.rs b/src/libos/src/net/socket/sockopt/get_sockbuf.rs index 951c6ba1..bd2ff8dd 100644 --- a/src/libos/src/net/socket/sockopt/get_sockbuf.rs +++ b/src/libos/src/net/socket/sockopt/get_sockbuf.rs @@ -1,7 +1,7 @@ crate::impl_ioctl_cmd! { - pub struct GetSndBufSizeCmd {} + pub struct GetSendBufSizeCmd {} } crate::impl_ioctl_cmd! { - pub struct GetRcvBufSizeCmd {} + pub struct GetRecvBufSizeCmd {} } diff --git a/src/libos/src/net/socket/sockopt/mod.rs b/src/libos/src/net/socket/sockopt/mod.rs index da663807..56b291ed 100644 --- a/src/libos/src/net/socket/sockopt/mod.rs +++ b/src/libos/src/net/socket/sockopt/mod.rs @@ -16,10 +16,10 @@ pub use get_domain::GetDomainCmd; pub use get_error::GetErrorCmd; pub use get_output::*; pub use get_peername::{AddrStorage, GetPeerNameCmd}; -pub use get_sockbuf::{GetRcvBufSizeCmd, GetSndBufSizeCmd}; +pub use get_sockbuf::{GetRecvBufSizeCmd, GetSendBufSizeCmd}; pub use get_type::GetTypeCmd; pub use set::{setsockopt_by_host, SetSockOptRawCmd}; -pub use set_sockbuf::{SetRcvBufSizeCmd, SetSndBufSizeCmd}; +pub use set_sockbuf::{SetRecvBufSizeCmd, SetSendBufSizeCmd}; pub use timeout::{ timeout_to_timeval, GetRecvTimeoutCmd, GetSendTimeoutCmd, SetRecvTimeoutCmd, SetSendTimeoutCmd, }; diff --git a/src/libos/src/net/socket/sockopt/set.rs b/src/libos/src/net/socket/sockopt/set.rs index 8118885f..b84a8484 100644 --- a/src/libos/src/net/socket/sockopt/set.rs +++ b/src/libos/src/net/socket/sockopt/set.rs @@ -2,14 +2,16 @@ use crate::{fs::IoctlCmd, prelude::*}; use libc::ocall::setsockopt as do_setsockopt; #[derive(Debug)] -pub struct SetSockOptRawCmd { +pub struct SetSockOptRawCmd<'a> { level: i32, optname: i32, - optval: &'static [u8], + optval: &'a [u8], } -impl SetSockOptRawCmd { - pub fn new(level: i32, optname: i32, optval: &'static [u8]) -> Self { +impl IoctlCmd for SetSockOptRawCmd<'static> {} + +impl<'a> SetSockOptRawCmd<'a> { + pub fn new(level: i32, optname: i32, optval: &'a [u8]) -> SetSockOptRawCmd<'a> { Self { level, optname, @@ -23,8 +25,6 @@ impl SetSockOptRawCmd { } } -impl IoctlCmd for SetSockOptRawCmd {} - pub fn setsockopt_by_host(fd: FileDesc, level: i32, optname: i32, optval: &[u8]) -> Result<()> { try_libc!(do_setsockopt( fd as _, diff --git a/src/libos/src/net/socket/sockopt/set_sockbuf.rs b/src/libos/src/net/socket/sockopt/set_sockbuf.rs index eddae439..8928b39f 100644 --- a/src/libos/src/net/socket/sockopt/set_sockbuf.rs +++ b/src/libos/src/net/socket/sockopt/set_sockbuf.rs @@ -1,23 +1,18 @@ use super::set::setsockopt_by_host; use crate::{fs::IoctlCmd, prelude::*}; -#[derive(Debug)] -pub struct SetSndBufSizeCmd { - buf_size: usize, +crate::impl_ioctl_cmd! { + pub struct SetSendBufSizeCmd {} } -impl SetSndBufSizeCmd { - pub fn new(buf_size: usize) -> Self { - Self { buf_size } - } - - pub fn buf_size(&self) -> usize { - self.buf_size - } +crate::impl_ioctl_cmd! { + pub struct SetRecvBufSizeCmd {} +} +impl SetSendBufSizeCmd { pub fn update_host(&self, fd: FileDesc) -> Result<()> { // The buf size for host call should be divided by 2 because the value will be doubled by host kernel. - let host_call_buf_size = (self.buf_size / 2).to_ne_bytes(); + let host_call_buf_size = (self.input / 2).to_ne_bytes(); // Setting SO_SNDBUF for host socket needs to respect /proc/sys/net/core/wmem_max. Thus, the value might be different on host, but it is fine. setsockopt_by_host( @@ -29,25 +24,10 @@ impl SetSndBufSizeCmd { } } -impl IoctlCmd for SetSndBufSizeCmd {} - -#[derive(Debug)] -pub struct SetRcvBufSizeCmd { - buf_size: usize, -} - -impl SetRcvBufSizeCmd { - pub fn new(buf_size: usize) -> Self { - Self { buf_size } - } - - pub fn buf_size(&self) -> usize { - self.buf_size - } - +impl SetRecvBufSizeCmd { pub fn update_host(&self, fd: FileDesc) -> Result<()> { // The buf size for host call should be divided by 2 because the value will be doubled by host kernel. - let host_call_buf_size = (self.buf_size / 2).to_ne_bytes(); + let host_call_buf_size = (self.input / 2).to_ne_bytes(); // Setting SO_RCVBUF for host socket needs to respect /proc/sys/net/core/rmem_max. Thus, the value might be different on host, but it is fine. setsockopt_by_host( @@ -58,5 +38,3 @@ impl SetRcvBufSizeCmd { ) } } - -impl IoctlCmd for SetRcvBufSizeCmd {} diff --git a/src/libos/src/net/socket/sockopt/timeout.rs b/src/libos/src/net/socket/sockopt/timeout.rs index 954a5001..8b9475f7 100644 --- a/src/libos/src/net/socket/sockopt/timeout.rs +++ b/src/libos/src/net/socket/sockopt/timeout.rs @@ -3,34 +3,12 @@ use crate::prelude::*; use libc::{suseconds_t, time_t}; use std::time::Duration; -#[derive(Debug)] -pub struct SetSendTimeoutCmd(Duration); - -impl IoctlCmd for SetSendTimeoutCmd {} - -impl SetSendTimeoutCmd { - pub fn new(timeout: Duration) -> Self { - Self(timeout) - } - - pub fn timeout(&self) -> &Duration { - &self.0 - } +crate::impl_ioctl_cmd! { + pub struct SetSendTimeoutCmd {} } -#[derive(Debug)] -pub struct SetRecvTimeoutCmd(Duration); - -impl IoctlCmd for SetRecvTimeoutCmd {} - -impl SetRecvTimeoutCmd { - pub fn new(timeout: Duration) -> Self { - Self(timeout) - } - - pub fn timeout(&self) -> &Duration { - &self.0 - } +crate::impl_ioctl_cmd! { + pub struct SetRecvTimeoutCmd {} } crate::impl_ioctl_cmd! { diff --git a/src/libos/src/net/socket/unix/mod.rs b/src/libos/src/net/socket/unix/mod.rs index f8260232..2080307a 100644 --- a/src/libos/src/net/socket/unix/mod.rs +++ b/src/libos/src/net/socket/unix/mod.rs @@ -5,12 +5,12 @@ mod stream; pub use self::stream::Stream; //TODO: rewrite this file when a new kind of uds is added -pub fn unix_socket(socket_type: Type, flags: SocketFlags, protocol: i32) -> Result { +pub fn unix_socket(socket_type: SocketType, flags: SocketFlags, protocol: i32) -> Result { if protocol != 0 && protocol != Domain::LOCAL as i32 { return_errno!(EPROTONOSUPPORT, "protocol is not supported"); } - if socket_type == Type::STREAM { + if socket_type == SocketType::STREAM { Ok(Stream::new(flags)) } else { return_errno!(ESOCKTNOSUPPORT, "only stream type is supported"); @@ -18,7 +18,7 @@ pub fn unix_socket(socket_type: Type, flags: SocketFlags, protocol: i32) -> Resu } pub fn socketpair( - socket_type: Type, + socket_type: SocketType, flags: SocketFlags, protocol: i32, ) -> Result<(Stream, Stream)> { @@ -26,7 +26,7 @@ pub fn socketpair( return_errno!(EPROTONOSUPPORT, "protocol is not supported"); } - if socket_type == Type::STREAM { + if socket_type == SocketType::STREAM { Stream::socketpair(flags) } else { return_errno!(ESOCKTNOSUPPORT, "only stream type is supported"); diff --git a/src/libos/src/net/socket/unix/stream/stream.rs b/src/libos/src/net/socket/unix/stream/stream.rs index fc1711e2..94ca7b47 100644 --- a/src/libos/src/net/socket/unix/stream/stream.rs +++ b/src/libos/src/net/socket/unix/stream/stream.rs @@ -72,7 +72,10 @@ impl Stream { return_errno!(ENOTCONN, "the socket is not connected"); } - pub fn bind(&self, addr: &mut UnixAddr) -> Result<()> { + pub fn bind(&self, addr: &UnixAddr) -> Result<()> { + let mut unix_addr = addr.clone(); + let addr = &mut unix_addr; + if let UnixAddr::File(inode_num, path) = addr { // create the corresponding file in the fs and fill Addr with its inode let corresponding_inode_num = { @@ -218,7 +221,7 @@ impl Stream { } // TODO: handle flags - pub fn sendto(&self, buf: &[u8], flags: SendFlags, addr: &Option) -> Result { + pub fn sendto(&self, buf: &[u8], flags: SendFlags, addr: Option<&UnixAddr>) -> Result { self.write(buf) } @@ -255,7 +258,7 @@ impl Stream { bufs: &mut [&mut [u8]], flags: RecvFlags, control: Option<&mut [u8]>, - ) -> Result<(usize, usize)> { + ) -> Result<(usize, Option, MsgFlags, usize)> { if !flags.is_empty() { warn!("unsupported flags: {:?}", flags); } @@ -288,7 +291,7 @@ impl Stream { 0 }; - Ok((data_len, control_len)) + Ok((data_len, None, MsgFlags::empty(), control_len)) } /// perform shutdown on the socket. diff --git a/src/libos/src/net/socket/uring/common/common.rs b/src/libos/src/net/socket/uring/common/common.rs index 864ee2cc..c7676f7f 100644 --- a/src/libos/src/net/socket/uring/common/common.rs +++ b/src/libos/src/net/socket/uring/common/common.rs @@ -18,7 +18,7 @@ use crate::prelude::*; /// The common parts of all stream sockets. pub struct Common { host_fd: FileDesc, - type_: Type, + type_: SocketType, nonblocking: AtomicBool, is_closed: AtomicBool, pollee: Pollee, @@ -30,7 +30,7 @@ pub struct Common { } impl Common { - pub fn new(type_: Type, nonblocking: bool, protocol: Option) -> Result { + pub fn new(type_: SocketType, nonblocking: bool, protocol: Option) -> Result { let domain_c = A::domain() as libc::c_int; let type_c = type_ as libc::c_int; let protocol = protocol.unwrap_or(0) as libc::c_int; @@ -56,11 +56,11 @@ impl Common { }) } - pub fn new_pair(sock_type: Type, nonblocking: bool) -> Result<(Self, Self)> { + pub fn new_pair(sock_type: SocketType, nonblocking: bool) -> Result<(Self, Self)> { return_errno!(EINVAL, "Unix is unsupported"); } - pub fn with_host_fd(host_fd: FileDesc, type_: Type, nonblocking: bool) -> Self { + pub fn with_host_fd(host_fd: FileDesc, type_: SocketType, nonblocking: bool) -> Self { let nonblocking = AtomicBool::new(nonblocking); let is_closed = AtomicBool::new(false); let pollee = Pollee::new(IoEvents::empty()); @@ -90,7 +90,7 @@ impl Common { self.host_fd } - pub fn type_(&self) -> Type { + pub fn type_(&self) -> SocketType { self.type_ } diff --git a/src/libos/src/net/socket/uring/datagram/generic.rs b/src/libos/src/net/socket/uring/datagram/generic.rs index f9b5d3af..d7675b34 100644 --- a/src/libos/src/net/socket/uring/datagram/generic.rs +++ b/src/libos/src/net/socket/uring/datagram/generic.rs @@ -20,7 +20,7 @@ pub struct DatagramSocket { impl DatagramSocket { pub fn new(nonblocking: bool) -> Result { - let common = Arc::new(Common::new(Type::DGRAM, nonblocking, None)?); + let common = Arc::new(Common::new(SocketType::DGRAM, nonblocking, None)?); let state = RwLock::new(State::new()); let sender = Sender::new(common.clone()); let receiver = Receiver::new(common.clone()); @@ -33,7 +33,7 @@ impl DatagramSocket { } pub fn new_pair(nonblocking: bool) -> Result<(Self, Self)> { - let (common1, common2) = Common::new_pair(Type::DGRAM, nonblocking)?; + let (common1, common2) = Common::new_pair(SocketType::DGRAM, nonblocking)?; let socket1 = Self::new_connected(common1); let socket2 = Self::new_connected(common2); Ok((socket1, socket2)) @@ -321,10 +321,10 @@ impl DatagramSocket { cmd.execute(self.host_fd())?; }, cmd: SetRecvTimeoutCmd => { - self.set_recv_timeout(*cmd.timeout()); + self.set_recv_timeout(*cmd.input()); }, cmd: SetSendTimeoutCmd => { - self.set_send_timeout(*cmd.timeout()); + self.set_send_timeout(*cmd.input()); }, cmd: GetRecvTimeoutCmd => { let timeval = timeout_to_timeval(self.recv_timeout()); diff --git a/src/libos/src/net/socket/uring/socket_file.rs b/src/libos/src/net/socket/uring/socket_file.rs index d7ee0595..9039bf0f 100644 --- a/src/libos/src/net/socket/uring/socket_file.rs +++ b/src/libos/src/net/socket/uring/socket_file.rs @@ -102,10 +102,10 @@ impl SocketFile { apply_fn_on_any_socket!(&self.socket, |socket| { socket.ioctl(cmd) }) } - pub fn get_type(&self) -> Type { + pub fn get_type(&self) -> SocketType { match self.socket { - AnySocket::Ipv4Stream(_) | AnySocket::Ipv6Stream(_) => Type::STREAM, - AnySocket::Ipv4Datagram(_) | AnySocket::Ipv6Datagram(_) => Type::DGRAM, + AnySocket::Ipv4Stream(_) | AnySocket::Ipv6Stream(_) => SocketType::STREAM, + AnySocket::Ipv4Datagram(_) | AnySocket::Ipv6Datagram(_) => SocketType::DGRAM, } } } @@ -115,11 +115,11 @@ impl SocketFile { pub fn new( domain: Domain, protocol: SocketProtocol, - socket_type: Type, + socket_type: SocketType, nonblocking: bool, ) -> Result { match socket_type { - Type::STREAM => { + SocketType::STREAM => { if protocol != SocketProtocol::IPPROTO_IP && protocol != SocketProtocol::IPPROTO_TCP { return_errno!(EPROTONOSUPPORT, "Protocol not supported"); @@ -140,7 +140,7 @@ impl SocketFile { let new_self = Self { socket: any_socket }; Ok(new_self) } - Type::DGRAM => { + SocketType::DGRAM => { if protocol != SocketProtocol::IPPROTO_IP && protocol != SocketProtocol::IPPROTO_UDP { return_errno!(EPROTONOSUPPORT, "Protocol not supported"); @@ -161,7 +161,7 @@ impl SocketFile { let new_self = Self { socket: any_socket }; Ok(new_self) } - Type::RAW => { + SocketType::RAW => { return_errno!(EINVAL, "RAW socket not supported"); } _ => { @@ -210,7 +210,7 @@ impl SocketFile { } } - pub fn bind(&self, addr: &mut AnyAddr) -> Result<()> { + pub fn bind(&self, addr: &AnyAddr) -> Result<()> { match &self.socket { AnySocket::Ipv4Stream(ipv4_stream) => { let ip_addr = addr.to_ipv4()?; diff --git a/src/libos/src/net/socket/uring/stream/mod.rs b/src/libos/src/net/socket/uring/stream/mod.rs index fa4e070c..de6c4616 100644 --- a/src/libos/src/net/socket/uring/stream/mod.rs +++ b/src/libos/src/net/socket/uring/stream/mod.rs @@ -57,7 +57,7 @@ impl StreamSocket { } pub fn new_pair(nonblocking: bool) -> Result<(Self, Self)> { - let (common1, common2) = Common::new_pair(Type::STREAM, nonblocking)?; + let (common1, common2) = Common::new_pair(SocketType::STREAM, nonblocking)?; let connected1 = ConnectedStream::new(Arc::new(common1)); let connected2 = ConnectedStream::new(Arc::new(common2)); let socket1 = Self::new_connected(connected1); @@ -373,10 +373,10 @@ impl StreamSocket { cmd.execute(self.host_fd())?; }, cmd: SetRecvTimeoutCmd => { - self.set_recv_timeout(*cmd.timeout()); + self.set_recv_timeout(*cmd.input()); }, cmd: SetSendTimeoutCmd => { - self.set_send_timeout(*cmd.timeout()); + self.set_send_timeout(*cmd.input()); }, cmd: GetRecvTimeoutCmd => { let timeval = timeout_to_timeval(self.recv_timeout()); @@ -386,21 +386,21 @@ impl StreamSocket { let timeval = timeout_to_timeval(self.send_timeout()); cmd.set_output(timeval); }, - cmd: SetSndBufSizeCmd => { + cmd: SetSendBufSizeCmd => { cmd.update_host(self.host_fd())?; - let buf_size = cmd.buf_size(); + let buf_size = *cmd.input(); self.set_kernel_send_buf_size(buf_size); }, - cmd: SetRcvBufSizeCmd => { + cmd: SetRecvBufSizeCmd => { cmd.update_host(self.host_fd())?; - let buf_size = cmd.buf_size(); + let buf_size = *cmd.input(); self.set_kernel_recv_buf_size(buf_size); }, - cmd: GetSndBufSizeCmd => { + cmd: GetSendBufSizeCmd => { let buf_size = SEND_BUF_SIZE.load(Ordering::Relaxed); cmd.set_output(buf_size); }, - cmd: GetRcvBufSizeCmd => { + cmd: GetRecvBufSizeCmd => { let buf_size = RECV_BUF_SIZE.load(Ordering::Relaxed); cmd.set_output(buf_size); }, @@ -454,6 +454,8 @@ impl StreamSocket { } fn set_kernel_send_buf_size(&self, buf_size: usize) { + // Setting the minimal buf_size to 128 Kbytes + let buf_size = (128 * 1024 + 1).max(buf_size); let state = self.state.read().unwrap(); match &*state { State::Init(_) | State::Listen(_) | State::Connect(_) => { @@ -467,6 +469,8 @@ impl StreamSocket { } fn set_kernel_recv_buf_size(&self, buf_size: usize) { + // Setting the minimal buf_size to 128 Kbytes + let buf_size = (128 * 1024 + 1).max(buf_size); let state = self.state.read().unwrap(); match &*state { State::Init(_) | State::Listen(_) | State::Connect(_) => { diff --git a/src/libos/src/net/socket/uring/stream/states/connected/recv.rs b/src/libos/src/net/socket/uring/stream/states/connected/recv.rs index 3a82cdfe..37afb6fd 100644 --- a/src/libos/src/net/socket/uring/stream/states/connected/recv.rs +++ b/src/libos/src/net/socket/uring/stream/states/connected/recv.rs @@ -191,7 +191,6 @@ impl ConnectedStream { // Init the callback invoked upon the completion of the async recv let stream = self.clone(); let complete_fn = move |retval: i32| { - // let mut inner = stream.receiver.inner.lock().unwrap(); let mut inner = stream.receiver.inner.lock(); trace!("recv request complete with retval: {:?}", retval); @@ -232,7 +231,9 @@ impl ConnectedStream { // ready to read. stream.common.pollee().add_events(Events::IN); - stream.do_recv(&mut inner); + if !stream.receiver.need_update() { + stream.do_recv(&mut inner); + } }; // Generate the async recv request diff --git a/src/libos/src/net/socket/uring/stream/states/connected/send.rs b/src/libos/src/net/socket/uring/stream/states/connected/send.rs index 6b0e8088..3226f4f5 100644 --- a/src/libos/src/net/socket/uring/stream/states/connected/send.rs +++ b/src/libos/src/net/socket/uring/stream/states/connected/send.rs @@ -183,7 +183,14 @@ impl ConnectedStream { inner.fatal = Some(errno); stream.common.set_errno(errno); - stream.common.pollee().add_events(Events::ERR); + + let events = if errno == ENOTCONN || errno == ECONNRESET || errno == ECONNREFUSED { + Events::HUP | Events::OUT | Events::ERR + } else { + Events::ERR + }; + + stream.common.pollee().add_events(events); return; } assert!(retval != 0); diff --git a/src/libos/src/net/socket/uring/stream/states/init.rs b/src/libos/src/net/socket/uring/stream/states/init.rs index 5edcecbe..0ef6078e 100644 --- a/src/libos/src/net/socket/uring/stream/states/init.rs +++ b/src/libos/src/net/socket/uring/stream/states/init.rs @@ -15,7 +15,7 @@ struct Inner { impl InitStream { pub fn new(nonblocking: bool) -> Result> { - let common = Arc::new(Common::new(Type::STREAM, nonblocking, None)?); + let common = Arc::new(Common::new(SocketType::STREAM, nonblocking, None)?); common.pollee().add_events(IoEvents::HUP | IoEvents::OUT); let inner = Mutex::new(Inner::new()); let new_self = Self { common, inner }; diff --git a/src/libos/src/net/socket/uring/stream/states/listen.rs b/src/libos/src/net/socket/uring/stream/states/listen.rs index 24a91da0..8278729b 100644 --- a/src/libos/src/net/socket/uring/stream/states/listen.rs +++ b/src/libos/src/net/socket/uring/stream/states/listen.rs @@ -121,7 +121,11 @@ impl ListenerStream { self.initiate_async_accepts(inner); let common = { - let common = Arc::new(Common::with_host_fd(accepted_fd, Type::STREAM, nonblocking)); + let common = Arc::new(Common::with_host_fd( + accepted_fd, + SocketType::STREAM, + nonblocking, + )); common.set_peer_addr(&accepted_addr); common }; diff --git a/src/libos/src/net/socket/util/addr/ipv4.rs b/src/libos/src/net/socket/util/addr/ipv4.rs index 29b818ce..d4ddc81a 100644 --- a/src/libos/src/net/socket/util/addr/ipv4.rs +++ b/src/libos/src/net/socket/util/addr/ipv4.rs @@ -1,7 +1,8 @@ +use crate::net::{Addr, Domain}; use std::any::Any; use std::fmt::{self, Debug}; -use super::{Addr, CSockAddr, Domain, RawAddr}; +use super::{CSockAddr, SockAddr}; use crate::prelude::*; /// An IPv4 socket address, consisting of an IPv4 address and a port. @@ -70,9 +71,9 @@ impl Ipv4SocketAddr { } } - pub fn to_raw(&self) -> RawAddr { + pub fn to_raw(&self) -> SockAddr { let (storage, len) = self.to_c_storage(); - RawAddr::from_c_storage(&storage, len) + SockAddr::from_c_storage(&storage, len) } pub fn ip(&self) -> &Ipv4Addr { diff --git a/src/libos/src/net/socket/util/addr/ipv6.rs b/src/libos/src/net/socket/util/addr/ipv6.rs index 764c0a54..a84903fa 100644 --- a/src/libos/src/net/socket/util/addr/ipv6.rs +++ b/src/libos/src/net/socket/util/addr/ipv6.rs @@ -1,7 +1,7 @@ use std::any::Any; use std::fmt::Debug; -use super::RawAddr; +use super::SockAddr; use super::{Addr, CSockAddr, Domain}; use crate::prelude::*; use libc::in6_addr; @@ -85,9 +85,9 @@ impl Ipv6SocketAddr { } } - pub fn to_raw(&self) -> RawAddr { + pub fn to_raw(&self) -> SockAddr { let (storage, len) = self.to_c_storage(); - RawAddr::from_c_storage(&storage, len) + SockAddr::from_c_storage(&storage, len) } pub fn ip(&self) -> &Ipv6Addr { diff --git a/src/libos/src/net/socket/util/addr/mod.rs b/src/libos/src/net/socket/util/addr/mod.rs index 66e08b2d..c09e0190 100644 --- a/src/libos/src/net/socket/util/addr/mod.rs +++ b/src/libos/src/net/socket/util/addr/mod.rs @@ -38,7 +38,7 @@ pub trait Addr: Clone + Debug + Default + PartialEq + Send + Sync { pub use self::c_sock_addr::CSockAddr; pub use self::ipv4::{Ipv4Addr, Ipv4SocketAddr}; pub use self::ipv6::{Ipv6Addr, Ipv6SocketAddr}; -pub use self::raw_addr::RawAddr; +pub use self::raw_addr::SockAddr; pub use self::unix_addr::UnixAddr; #[cfg(test)] diff --git a/src/libos/src/net/socket/util/addr/raw_addr.rs b/src/libos/src/net/socket/util/addr/raw_addr.rs index 3f355c14..7fde0f6b 100644 --- a/src/libos/src/net/socket/util/addr/raw_addr.rs +++ b/src/libos/src/net/socket/util/addr/raw_addr.rs @@ -2,22 +2,22 @@ use super::*; use std::*; #[derive(Copy, Clone)] -pub struct RawAddr { +pub struct SockAddr { storage: libc::sockaddr_storage, len: usize, } // TODO: add more fields -impl fmt::Debug for RawAddr { +impl fmt::Debug for SockAddr { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("RawAddr") + f.debug_struct("SockAddr") .field("family", &Domain::try_from(self.storage.ss_family).unwrap()) .field("len", &self.len) .finish() } } -impl RawAddr { +impl SockAddr { pub fn from_c_storage(c_addr: &libc::sockaddr_storage, c_addr_len: usize) -> Self { Self { storage: *c_addr, @@ -118,7 +118,7 @@ impl RawAddr { } } -impl Default for RawAddr { +impl Default for SockAddr { fn default() -> Self { let mut storage: libc::sockaddr_storage = unsafe { mem::zeroed() }; Self { diff --git a/src/libos/src/net/socket/util/addr/unix_addr.rs b/src/libos/src/net/socket/util/addr/unix_addr.rs index 4a4ac271..9249c51f 100644 --- a/src/libos/src/net/socket/util/addr/unix_addr.rs +++ b/src/libos/src/net/socket/util/addr/unix_addr.rs @@ -106,7 +106,7 @@ impl UnixAddr { /// The '/0' at the end of Self::File counts match self.path_str() { Ok(str) => str.len() + 1 + *SUN_PATH_OFFSET, - Err(_) => 0, + Err(_) => std::mem::size_of::(), } } @@ -123,9 +123,9 @@ impl UnixAddr { c_un_addr.to_c_storage() } - pub fn to_raw(&self) -> RawAddr { + pub fn to_raw(&self) -> SockAddr { let (storage, addr_len) = self.to_c_storage(); - RawAddr::from_c_storage(&storage, addr_len) + SockAddr::from_c_storage(&storage, addr_len) } fn to_c(&self) -> (libc::sockaddr_un, usize) { diff --git a/src/libos/src/net/socket/util/any_addr.rs b/src/libos/src/net/socket/util/any_addr.rs index a8a0492f..cec5d42d 100644 --- a/src/libos/src/net/socket/util/any_addr.rs +++ b/src/libos/src/net/socket/util/any_addr.rs @@ -3,7 +3,7 @@ use std::mem::{self, MaybeUninit}; use crate::net::socket::Domain; use crate::prelude::*; -use super::{Addr, CSockAddr, Ipv4Addr, Ipv4SocketAddr, Ipv6SocketAddr, RawAddr, UnixAddr}; +use super::{Addr, CSockAddr, Ipv4Addr, Ipv4SocketAddr, Ipv6SocketAddr, SockAddr, UnixAddr}; use num_enum::IntoPrimitive; use std::path::Path; @@ -12,7 +12,7 @@ pub enum AnyAddr { Ipv4(Ipv4SocketAddr), Ipv6(Ipv6SocketAddr), Unix(UnixAddr), - Raw(RawAddr), + Raw(SockAddr), Unspec, } @@ -33,7 +33,7 @@ impl AnyAddr { Self::Unix(unix_addr) } _ => { - let raw_addr = RawAddr::from_c_storage(c_addr, c_addr_len); + let raw_addr = SockAddr::from_c_storage(c_addr, c_addr_len); Self::Raw(raw_addr) } }; @@ -55,7 +55,7 @@ impl AnyAddr { } } - pub fn to_raw(&self) -> RawAddr { + pub fn to_raw(&self) -> SockAddr { match self { Self::Ipv4(ipv4_addr) => ipv4_addr.to_raw(), Self::Ipv6(ipv6_addr) => ipv6_addr.to_raw(), @@ -65,7 +65,7 @@ impl AnyAddr { let mut sockaddr_storage = unsafe { MaybeUninit::::uninit().assume_init() }; sockaddr_storage.ss_family = libc::AF_UNSPEC as _; - RawAddr::from_c_storage(&sockaddr_storage, mem::size_of::()) + SockAddr::from_c_storage(&sockaddr_storage, mem::size_of::()) } } } @@ -77,13 +77,6 @@ impl AnyAddr { } } - pub fn as_ipv4(&self) -> Option<&Ipv4SocketAddr> { - match self { - Self::Ipv4(ipv4_addr) => Some(ipv4_addr), - _ => None, - } - } - pub fn to_ipv4(&self) -> Result<&Ipv4SocketAddr> { match self { Self::Ipv4(ipv4_addr) => Ok(ipv4_addr), diff --git a/src/libos/src/net/socket/util/flags.rs b/src/libos/src/net/socket/util/flags.rs index 6ba10e0b..6a0d26ab 100644 --- a/src/libos/src/net/socket/util/flags.rs +++ b/src/libos/src/net/socket/util/flags.rs @@ -1,4 +1,6 @@ use bitflags::bitflags; +use sgx_trts::libc; +use std::ffi::c_uint; // Flags to use when sending data through a socket bitflags! { @@ -37,3 +39,18 @@ bitflags! { const MSG_NOTIFICATION = 0x8000; // Only applicable to SCTP socket } } + +// Flags to use when creating a new socket +bitflags! { + pub struct SocketFlags: i32 { + const SOCK_NONBLOCK = 0x800; + const SOCK_CLOEXEC = 0x80000; + } +} + +#[repr(C)] +#[derive(Copy, Clone)] +pub struct mmsghdr { + pub msg_hdr: libc::msghdr, + pub msg_len: c_uint, +} diff --git a/src/libos/src/net/socket/util/mod.rs b/src/libos/src/net/socket/util/mod.rs index 31f6dcf8..a9358e14 100644 --- a/src/libos/src/net/socket/util/mod.rs +++ b/src/libos/src/net/socket/util/mod.rs @@ -15,13 +15,13 @@ mod shutdown; mod r#type; pub use self::addr::{ - Addr, CSockAddr, Ipv4Addr, Ipv4SocketAddr, Ipv6SocketAddr, RawAddr, UnixAddr, + Addr, CSockAddr, Ipv4Addr, Ipv4SocketAddr, Ipv6SocketAddr, SockAddr, UnixAddr, }; pub use self::any_addr::AnyAddr; pub use self::domain::Domain; -pub use self::flags::{MsgFlags, RecvFlags, SendFlags}; +pub use self::flags::{mmsghdr, MsgFlags, RecvFlags, SendFlags, SocketFlags}; pub use self::iovs::{Iovs, IovsMut, SliceAsLibcIovec}; pub use self::msg::{CMessages, CmsgData}; pub use self::protocol::SocketProtocol; -pub use self::r#type::Type; +pub use self::r#type::SocketType; pub use self::shutdown::Shutdown; diff --git a/src/libos/src/net/socket/util/type.rs b/src/libos/src/net/socket/util/type.rs index 1c861d23..7b1e426f 100644 --- a/src/libos/src/net/socket/util/type.rs +++ b/src/libos/src/net/socket/util/type.rs @@ -4,7 +4,7 @@ use num_enum::{IntoPrimitive, TryFromPrimitive}; /// A network type. #[derive(Clone, Copy, Debug, Eq, PartialEq, IntoPrimitive, TryFromPrimitive)] #[repr(i32)] -pub enum Type { +pub enum SocketType { STREAM = libc::SOCK_STREAM, DGRAM = libc::SOCK_DGRAM, RAW = libc::SOCK_RAW, diff --git a/src/libos/src/net/syscalls.rs b/src/libos/src/net/syscalls.rs index 1a91158f..664f973b 100644 --- a/src/libos/src/net/syscalls.rs +++ b/src/libos/src/net/syscalls.rs @@ -1,5 +1,4 @@ -use super::socket::{MsgFlags, SocketProtocol}; -use super::{socket::uring::socket_file::SocketFile, *}; +use super::socket::{mmsghdr, MsgFlags, SocketFlags, SocketProtocol}; use atomic::Ordering; use core::f32::consts::E; @@ -17,15 +16,6 @@ use std::convert::TryFrom; use time::{timespec_t, timeval_t}; use util::mem_util::from_user; -use super::socket::sockopt::{ - GetAcceptConnCmd, GetDomainCmd, GetErrorCmd, GetOutputAsBytes, GetPeerNameCmd, - GetRcvBufSizeCmd, GetRecvTimeoutCmd, GetSendTimeoutCmd, GetSndBufSizeCmd, GetSockOptRawCmd, - GetTypeCmd, SetRcvBufSizeCmd, SetRecvTimeoutCmd, SetSendTimeoutCmd, SetSndBufSizeCmd, - SetSockOptRawCmd, SockOptName, -}; -use super::socket::uring::UringSocketType; -use super::socket::AnyAddr; - use super::*; use crate::fs::StatusFlags; @@ -36,20 +26,13 @@ use crate::prelude::*; const SOMAXCONN: u32 = 4096; const SOCONN_DEFAULT: u32 = 16; -#[repr(C)] -#[derive(Copy, Clone)] -pub struct mmsghdr { - pub msg_hdr: libc::msghdr, - pub msg_len: c_uint, -} - pub fn do_socket(domain: c_int, socket_type: c_int, protocol: c_int) -> Result { let domain = Domain::try_from(domain as u16)?; let flags = SocketFlags::from_bits_truncate(socket_type); let type_bits = socket_type & !flags.bits(); let socket_type = - Type::try_from(type_bits).map_err(|_| errno!(EINVAL, "invalid socket type"))?; + SocketType::try_from(type_bits).map_err(|_| errno!(EINVAL, "invalid socket type"))?; debug!( "socket domain: {:?}, type: {:?}, protocol: {:?}", @@ -64,7 +47,8 @@ pub fn do_socket(domain: c_int, socket_type: c_int, protocol: c_int) -> Result Result Result { - let addr_len = addr_len as usize; - let sockaddr_storage = copy_sock_addr_from_user(addr, addr_len)?; - let mut addr = AnyAddr::from_c_storage(&sockaddr_storage, addr_len)?; + let addr = { + let addr_len = addr_len as usize; + let sockaddr_storage = copy_sock_addr_from_user(addr, addr_len)?; + let addr = AnyAddr::from_c_storage(&sockaddr_storage, addr_len)?; + addr + }; + trace!("bind to addr: {:?}", addr); let file_ref = current!().file(fd as FileDesc)?; if let Ok(socket) = file_ref.as_host_socket() { - let mut raw_addr = addr.to_raw(); - socket.bind(&mut raw_addr)?; + let raw_addr = addr.to_raw(); + socket.bind(&raw_addr)?; } else if let Ok(unix_socket) = file_ref.as_unix_socket() { - let mut unix_addr = (addr.to_unix()?).clone(); - unix_socket.bind(&mut unix_addr)?; + let unix_addr = addr.to_unix()?; + unix_socket.bind(unix_addr)?; } else if let Ok(uring_socket) = file_ref.as_uring_socket() { - uring_socket.bind(&mut addr)?; + uring_socket.bind(&addr)?; } else { return_errno!(ENOTSOCK, "not a socket"); } @@ -159,18 +147,21 @@ pub fn do_connect( let file_ref = current!().file(fd as FileDesc)?; if let Ok(socket) = file_ref.as_host_socket() { let addr_option = if addr_set { - Some(unsafe { RawAddr::try_from_raw(addr, addr_len as u32)? }) + Some(unsafe { SockAddr::try_from_raw(addr, addr_len as u32)? }) } else { None }; - socket.connect(&addr_option)?; + socket.connect(addr_option.as_ref())?; return Ok(0); }; - let addr_len = addr_len as usize; - let sockaddr_storage = copy_sock_addr_from_user(addr, addr_len)?; - let addr = AnyAddr::from_c_storage(&sockaddr_storage, addr_len)?; + let addr = { + let addr_len = addr_len as usize; + let sockaddr_storage = copy_sock_addr_from_user(addr, addr_len)?; + let addr = AnyAddr::from_c_storage(&sockaddr_storage, addr_len)?; + addr + }; if let Ok(unix_socket) = file_ref.as_unix_socket() { // TODO: support AF_UNSPEC address for datagram socket use @@ -415,19 +406,17 @@ pub fn do_sendto( let file_ref = current!().file(fd as FileDesc)?; if let Ok(host_socket) = file_ref.as_host_socket() { - let addr = addr.map(|any_addr| any_addr.to_raw()); - host_socket - .sendto(buf, send_flags, &addr) + .sendto(buf, send_flags, addr) .map(|u| u as isize) } else if let Ok(unix_socket) = file_ref.as_unix_socket() { let addr = match addr { - Some(any_addr) => Some(any_addr.to_unix()?.clone()), + Some(ref any_addr) => Some(any_addr.to_unix()?), None => None, }; unix_socket - .sendto(buf, send_flags, &addr) + .sendto(buf, send_flags, addr) .map(|u| u as isize) } else if let Ok(uring_socket) = file_ref.as_uring_socket() { uring_socket @@ -458,9 +447,7 @@ pub fn do_recvfrom( let file_ref = current!().file(fd as FileDesc)?; let (data_len, addr_recv) = if let Ok(socket) = file_ref.as_host_socket() { - socket - .recvfrom(buf, recv_flags) - .map(|(len, addr_recv)| (len, addr_recv.map(|raw_addr| AnyAddr::Raw(raw_addr))))? + socket.recvfrom(buf, recv_flags)? } else if let Ok(unix_socket) = file_ref.as_unix_socket() { unix_socket .recvfrom(buf, recv_flags) @@ -496,7 +483,7 @@ pub fn do_socketpair( let file_flags = SocketFlags::from_bits_truncate(socket_type); let close_on_spawn = file_flags.contains(SocketFlags::SOCK_CLOEXEC); - let sock_type = Type::try_from(socket_type & (!file_flags.bits())) + let sock_type = SocketType::try_from(socket_type & (!file_flags.bits())) .map_err(|_| errno!(EINVAL, "invalid socket type"))?; let domain = Domain::try_from(domain as u16)?; @@ -526,9 +513,8 @@ pub fn do_sendmsg(fd: c_int, msg_ptr: *const libc::msghdr, flags_c: c_int) -> Re let file_ref = current!().file(fd as FileDesc)?; if let Ok(host_socket) = file_ref.as_host_socket() { - let raw_addr = addr.map(|addr| addr.to_raw()); host_socket - .sendmsg(&bufs[..], flags, &raw_addr, control) + .sendmsg(&bufs[..], flags, addr, control) .map(|bytes_send| bytes_send as isize) } else if let Ok(socket) = file_ref.as_unix_socket() { socket @@ -554,20 +540,9 @@ pub fn do_recvmsg(fd: c_int, msg_mut_ptr: *mut libc::msghdr, flags_c: c_int) -> let file_ref = current!().file(fd as FileDesc)?; let (bytes_recv, recv_addr, msg_flags, msg_controllen) = if let Ok(host_socket) = file_ref.as_host_socket() { - host_socket.recvmsg(&mut bufs[..], flags, control).map( - |(bytes, addr, msg, controllen)| { - ( - bytes, - addr.map(|raw_addr| AnyAddr::Raw(raw_addr)), - msg, - controllen, - ) - }, - )? + host_socket.recvmsg(&mut bufs[..], flags, control)? } else if let Ok(unix_socket) = file_ref.as_unix_socket() { - unix_socket.recvmsg(&mut bufs[..], flags, control).map( - |(bytes_recvd, control_len)| (bytes_recvd, None, MsgFlags::empty(), control_len), - )? + unix_socket.recvmsg(&mut bufs[..], flags, control)? } else if let Ok(uring_socket) = file_ref.as_uring_socket() { uring_socket.recvmsg(&mut bufs[..], flags, control)? } else { @@ -610,13 +585,12 @@ pub fn do_sendmmsg( if let Ok(host_socket) = file_ref.as_host_socket() { for mmsg in (msgvec) { - let (any_addr, bufs, control) = extract_msghdr_from_user(&mmsg.msg_hdr)?; - let raw_addr = any_addr.map(|any_addr| any_addr.to_raw()); + let (addr, bufs, control) = extract_msghdr_from_user(&mmsg.msg_hdr)?; if host_socket - .sendmsg(&bufs[..], flags, &raw_addr, control) + .sendmsg(&bufs[..], flags, addr, control) .map(|bytes_send| { - mmsg.msg_len += bytes_send as c_uint; + mmsg.msg_len = bytes_send as c_uint; bytes_send as isize }) .is_ok() @@ -635,7 +609,7 @@ pub fn do_sendmmsg( if uring_socket .sendmsg(&bufs[..], addr, flags, control) .map(|bytes_send| { - mmsg.msg_len += bytes_send as c_uint; + mmsg.msg_len = bytes_send as c_uint; bytes_send as isize }) .is_ok() @@ -1021,18 +995,17 @@ fn copy_sock_addr_from_user( let sockaddr_src_buf = from_user::make_slice(addr as *const u8, addr_len)?; let sockaddr_storage = { - // Safety. The content will be initialized before function returns. - let mut sockaddr_storage = - unsafe { MaybeUninit::::uninit().assume_init() }; + let mut sockaddr_storage = MaybeUninit::::uninit(); // Safety. The dst slice is the only mutable reference to the sockaddr_storage let sockaddr_dst_buf = unsafe { - let ptr = &mut sockaddr_storage as *mut _ as *mut u8; + let ptr = sockaddr_storage.as_mut_ptr() as *mut u8; let len = addr_len; std::slice::from_raw_parts_mut(ptr, len) }; sockaddr_dst_buf.copy_from_slice(sockaddr_src_buf); - sockaddr_storage + unsafe { sockaddr_storage.assume_init() } }; + Ok(sockaddr_storage) } @@ -1088,7 +1061,7 @@ fn new_uring_getsockopt_cmd( level: i32, optname: i32, optlen: u32, - socket_type: Type, + socket_type: SocketType, ) -> Result> { if level != libc::SOL_SOCKET { return Ok(Box::new(GetSockOptRawCmd::new(level, optname, optlen))); @@ -1106,17 +1079,17 @@ fn new_uring_getsockopt_cmd( SockOptName::SO_RCVTIMEO_OLD => Box::new(GetRecvTimeoutCmd::new(())), SockOptName::SO_SNDTIMEO_OLD => Box::new(GetSendTimeoutCmd::new(())), SockOptName::SO_SNDBUF => { - if socket_type == Type::STREAM { + if socket_type == SocketType::STREAM { // Implement dynamic buf size for stream socket only. - Box::new(GetSndBufSizeCmd::new(())) + Box::new(GetSendBufSizeCmd::new(())) } else { Box::new(GetSockOptRawCmd::new(level, optname, optlen)) } } SockOptName::SO_RCVBUF => { - if socket_type == Type::STREAM { + if socket_type == SocketType::STREAM { // Implement dynamic buf size for stream socket only. - Box::new(GetRcvBufSizeCmd::new(())) + Box::new(GetRecvBufSizeCmd::new(())) } else { Box::new(GetSockOptRawCmd::new(level, optname, optlen)) } @@ -1163,161 +1136,153 @@ fn new_uring_setsockopt_cmd( level: i32, optname: i32, optval: &'static [u8], - socket_type: Type, + socket_type: SocketType, ) -> Result> { if level != libc::SOL_SOCKET { return Ok(Box::new(SetSockOptRawCmd::new(level, optname, optval))); } + if optval.len() == 0 { + return_errno!(EINVAL, "Not a valid optval length"); + } + let opt = SockOptName::try_from(optname).map_err(|_| errno!(ENOPROTOOPT, "Not a valid optname"))?; - let enable_uring = ENABLE_URING.load(Ordering::Relaxed); - if !enable_uring { - Ok(match opt { - SockOptName::SO_ACCEPTCONN - | SockOptName::SO_DOMAIN - | SockOptName::SO_PEERNAME - | SockOptName::SO_TYPE - | SockOptName::SO_ERROR - | SockOptName::SO_PEERCRED - | SockOptName::SO_SNDLOWAT - | SockOptName::SO_PEERSEC - | SockOptName::SO_PROTOCOL - | SockOptName::SO_MEMINFO - | SockOptName::SO_INCOMING_NAPI_ID - | SockOptName::SO_COOKIE - | SockOptName::SO_PEERGROUPS => return_errno!(ENOPROTOOPT, "it's a read-only option"), - _ => Box::new(SetSockOptRawCmd::new(level, optname, optval)), - }) - } else { - Ok(match opt { - SockOptName::SO_ACCEPTCONN - | SockOptName::SO_DOMAIN - | SockOptName::SO_PEERNAME - | SockOptName::SO_TYPE - | SockOptName::SO_ERROR - | SockOptName::SO_PEERCRED - | SockOptName::SO_SNDLOWAT - | SockOptName::SO_PEERSEC - | SockOptName::SO_PROTOCOL - | SockOptName::SO_MEMINFO - | SockOptName::SO_INCOMING_NAPI_ID - | SockOptName::SO_COOKIE - | SockOptName::SO_PEERGROUPS => return_errno!(ENOPROTOOPT, "it's a read-only option"), - SockOptName::SO_RCVTIMEO_OLD => { - let mut timeout: *const libc::timeval = std::ptr::null(); - if optval.len() >= std::mem::size_of::() { - timeout = optval as *const _ as *const libc::timeval; - } else { - return_errno!(EINVAL, "invalid timeout option"); - } - let timeout = unsafe { - let secs = if (*timeout).tv_sec < 0 { - 0 - } else { - (*timeout).tv_sec - }; - - let usec = (*timeout).tv_usec; - if usec < 0 || usec > 1000000 || (usec as u32).checked_mul(1000).is_none() { - return_errno!(EDOM, "time struct value is invalid"); - } - Duration::new(secs as u64, (*timeout).tv_usec as u32 * 1000) - }; - trace!("recv timeout = {:?}", timeout); - Box::new(SetRecvTimeoutCmd::new(timeout)) + Ok(match opt { + SockOptName::SO_ACCEPTCONN + | SockOptName::SO_DOMAIN + | SockOptName::SO_PEERNAME + | SockOptName::SO_TYPE + | SockOptName::SO_ERROR + | SockOptName::SO_PEERCRED + | SockOptName::SO_SNDLOWAT + | SockOptName::SO_PEERSEC + | SockOptName::SO_PROTOCOL + | SockOptName::SO_MEMINFO + | SockOptName::SO_INCOMING_NAPI_ID + | SockOptName::SO_COOKIE + | SockOptName::SO_PEERGROUPS => return_errno!(ENOPROTOOPT, "it's a read-only option"), + SockOptName::SO_RCVTIMEO_OLD => { + let mut timeout: *const libc::timeval = std::ptr::null(); + if optval.len() >= std::mem::size_of::() { + timeout = optval as *const _ as *const libc::timeval; + } else { + return_errno!(EINVAL, "invalid timeout option"); } - SockOptName::SO_SNDTIMEO_OLD => { - let mut timeout: *const libc::timeval = std::ptr::null(); - if optval.len() >= std::mem::size_of::() { - timeout = optval as *const _ as *const libc::timeval; + let timeout = unsafe { + let secs = if (*timeout).tv_sec < 0 { + 0 } else { - return_errno!(EINVAL, "invalid timeout option"); - } - let timeout = unsafe { - let secs = if (*timeout).tv_sec < 0 { - 0 - } else { - (*timeout).tv_sec - }; - - let usec = (*timeout).tv_usec; - if usec < 0 || usec > 1000000 || (usec as u32).checked_mul(1000).is_none() { - return_errno!(EDOM, "time struct value is invalid"); - } - Duration::new(secs as u64, usec as u32 * 1000) + (*timeout).tv_sec }; - trace!("send timeout = {:?}", timeout); - Box::new(SetSendTimeoutCmd::new(timeout)) + + let usec = (*timeout).tv_usec; + if usec < 0 || usec > 1000000 || (usec as u32).checked_mul(1000).is_none() { + return_errno!(EDOM, "time struct value is invalid"); + } + Duration::new(secs as u64, (*timeout).tv_usec as u32 * 1000) + }; + trace!("recv timeout = {:?}", timeout); + Box::new(SetRecvTimeoutCmd::new(timeout)) + } + SockOptName::SO_SNDTIMEO_OLD => { + let mut timeout: *const libc::timeval = std::ptr::null(); + if optval.len() >= std::mem::size_of::() { + timeout = optval as *const _ as *const libc::timeval; + } else { + return_errno!(EINVAL, "invalid timeout option"); } - SockOptName::SO_SNDBUF => { + let timeout = unsafe { + let secs = if (*timeout).tv_sec < 0 { + 0 + } else { + (*timeout).tv_sec + }; + + let usec = (*timeout).tv_usec; + if usec < 0 || usec > 1000000 || (usec as u32).checked_mul(1000).is_none() { + return_errno!(EDOM, "time struct value is invalid"); + } + Duration::new(secs as u64, usec as u32 * 1000) + }; + trace!("send timeout = {:?}", timeout); + Box::new(SetSendTimeoutCmd::new(timeout)) + } + SockOptName::SO_SNDBUF => { + // Implement dynamic buf size for stream socket only. + if socket_type != SocketType::STREAM { + Box::new(SetSockOptRawCmd::new(level, optname, optval)) + } else { + // Based on the man page: The minimum (doubled) value for this option is 2048. + // However, the test on Linux shows the minimum value (doubled) is 4608. Here, we just + // use the same value as Linux. + let min_size = 128 * 1024; + // For the max value, we choose 4MB (doubled) to assure the libos kernel buf won't be the bottleneck. + let max_size = 2 * 1024 * 1024; + + if optval.len() > 8 { + return_errno!(EINVAL, "optval size is invalid"); + } + + let mut send_buf_size = { + let mut size = [0 as u8; std::mem::size_of::()]; + let start_offset = size.len() - optval.len(); + size[start_offset..].copy_from_slice(optval); + usize::from_ne_bytes(size) + }; + trace!("set SO_SNDBUF size = {:?}", send_buf_size); + if send_buf_size < min_size { + send_buf_size = min_size; + } + if send_buf_size > max_size { + send_buf_size = max_size; + } + // Based on man page: The kernel doubles this value (to allow space for bookkeeping overhead) + // when it is set using setsockopt(2), and this doubled value is returned by getsockopt(2). + send_buf_size *= 2; + Box::new(SetSendBufSizeCmd::new(send_buf_size)) + } + } + SockOptName::SO_RCVBUF => { + if socket_type != SocketType::STREAM { + Box::new(SetSockOptRawCmd::new(level, optname, optval)) + } else { // Implement dynamic buf size for stream socket only. - if socket_type != Type::STREAM { - Box::new(SetSockOptRawCmd::new(level, optname, optval)) - } else { - // Based on the man page: The minimum (doubled) value for this option is 2048. - // However, the test on Linux shows the minimum value (doubled) is 4608. Here, we just - // use the same value as Linux. - let min_size = 1152; - // For the max value, we choose 4MB (doubled) to assure the libos kernel buf won't be the bottleneck. - let max_size = 2 * 1024 * 1024; + info!("optval = {:?}", optval); + // Based on the man page: The minimum (doubled) value for this option is 256. + // However, the test on Linux shows the minimum value (doubled) is 2304. Here, we just + // use the same value as Linux. + let min_size = 128 * 1024; + // For the max value, we choose 4MB (doubled) to assure the libos kernel buf won't be the bottleneck. + let max_size = 2 * 1024 * 1024; - let mut send_buf_size = { - let mut size = [0 as u8; std::mem::size_of::()]; - let start_offset = size.len() - optval.len(); - size[start_offset..].copy_from_slice(optval); - usize::from_ne_bytes(size) - }; - trace!("set SO_SNDBUF size = {:?}", send_buf_size); - if send_buf_size < min_size { - send_buf_size = min_size; - } - if send_buf_size > max_size { - send_buf_size = max_size; - } - // Based on man page: The kernel doubles this value (to allow space for bookkeeping overhead) - // when it is set using setsockopt(2), and this doubled value is returned by getsockopt(2). - send_buf_size *= 2; - Box::new(SetSndBufSizeCmd::new(send_buf_size)) + if optval.len() > 8 { + return_errno!(EINVAL, "optval size is invalid"); } - } - SockOptName::SO_RCVBUF => { - if socket_type != Type::STREAM { - Box::new(SetSockOptRawCmd::new(level, optname, optval)) - } else { - // Implement dynamic buf size for stream socket only. - info!("optval = {:?}", optval); - // Based on the man page: The minimum (doubled) value for this option is 256. - // However, the test on Linux shows the minimum value (doubled) is 2304. Here, we just - // use the same value as Linux. - let min_size = 1152; - // For the max value, we choose 4MB (doubled) to assure the libos kernel buf won't be the bottleneck. - let max_size = 2 * 1024 * 1024; - let mut recv_buf_size = { - let mut size = [0 as u8; std::mem::size_of::()]; - let start_offset = size.len() - optval.len(); - size[start_offset..].copy_from_slice(optval); - usize::from_ne_bytes(size) - }; - trace!("set SO_RCVBUF size = {:?}", recv_buf_size); - if recv_buf_size < min_size { - recv_buf_size = min_size; - } - - if recv_buf_size > max_size { - recv_buf_size = max_size - } - // Based on man page: The kernel doubles this value (to allow space for bookkeeping overhead) - // when it is set using setsockopt(2), and this doubled value is returned by getsockopt(2). - recv_buf_size *= 2; - Box::new(SetRcvBufSizeCmd::new(recv_buf_size)) + let mut recv_buf_size = { + let mut size = [0 as u8; std::mem::size_of::()]; + let start_offset = size.len() - optval.len(); + size[start_offset..].copy_from_slice(optval); + usize::from_ne_bytes(size) + }; + trace!("set SO_RCVBUF size = {:?}", recv_buf_size); + if recv_buf_size < min_size { + recv_buf_size = min_size; } + + if recv_buf_size > max_size { + recv_buf_size = max_size + } + // Based on man page: The kernel doubles this value (to allow space for bookkeeping overhead) + // when it is set using setsockopt(2), and this doubled value is returned by getsockopt(2). + recv_buf_size *= 2; + Box::new(SetRecvBufSizeCmd::new(recv_buf_size)) } - _ => Box::new(SetSockOptRawCmd::new(level, optname, optval)), - }) - } + } + _ => Box::new(SetSockOptRawCmd::new(level, optname, optval)), + }) } fn get_optval(cmd: &dyn IoctlCmd) -> Result<&[u8]> { @@ -1346,10 +1311,10 @@ fn get_optval(cmd: &dyn IoctlCmd) -> Result<&[u8]> { cmd : GetSendTimeoutCmd => { cmd.get_output_as_bytes() }, - cmd : GetSndBufSizeCmd => { + cmd : GetSendBufSizeCmd => { cmd.get_output_as_bytes() }, - cmd : GetRcvBufSizeCmd => { + cmd : GetRecvBufSizeCmd => { cmd.get_output_as_bytes() }, _ => { @@ -1491,11 +1456,3 @@ fn extract_msghdr_mut_from_user<'a>( Ok((msg_mut, name, control, bufs)) } - -// Flags to use when creating a new socket -bitflags! { - pub struct SocketFlags: i32 { - const SOCK_NONBLOCK = 0x800; - const SOCK_CLOEXEC = 0x80000; - } -} diff --git a/src/libos/src/prelude.rs b/src/libos/src/prelude.rs index 653fdfe3..6571a86c 100644 --- a/src/libos/src/prelude.rs +++ b/src/libos/src/prelude.rs @@ -17,8 +17,7 @@ pub use std::sync::{ pub use crate::error::Result; pub use crate::error::*; pub use crate::fs::{File, FileDesc, FileRef}; -pub use crate::net::socket::util::Addr; -pub use crate::net::socket::{Domain, RecvFlags, SendFlags, Shutdown, Type}; +pub use crate::net::{Addr, Domain, RecvFlags, SendFlags, Shutdown, SocketType}; pub use crate::process::{pid_t, uid_t}; pub use crate::util::sync::RwLock; pub use crate::util::sync::{Mutex, MutexGuard}; diff --git a/src/libos/src/process/mod.rs b/src/libos/src/process/mod.rs index 9300fa74..ec6f5aed 100644 --- a/src/libos/src/process/mod.rs +++ b/src/libos/src/process/mod.rs @@ -14,7 +14,6 @@ use crate::misc::ResourceLimits; use crate::prelude::*; use crate::sched::{NiceValue, SchedAgent}; use crate::signal::{SigDispositions, SigQueues}; -use crate::util::sync::Mutex; use crate::vm::ProcessVM; use self::pgrp::ProcessGrp; diff --git a/src/libos/src/util/sync/mutex.rs b/src/libos/src/util/sync/mutex.rs index 2701e237..9e31a588 100644 --- a/src/libos/src/util/sync/mutex.rs +++ b/src/libos/src/util/sync/mutex.rs @@ -11,20 +11,20 @@ use atomic::Ordering; use crate::process::{futex_wait, futex_wake}; #[derive(Default)] -pub struct Mutex { - value: UnsafeCell, +pub struct Mutex { inner: Box, + value: UnsafeCell, } -unsafe impl Sync for Mutex {} -unsafe impl Send for Mutex {} +unsafe impl Sync for Mutex {} +unsafe impl Send for Mutex {} -pub struct MutexGuard<'a, T: 'a> { +pub struct MutexGuard<'a, T: ?Sized + 'a> { inner: &'a Mutex, } -impl !Send for MutexGuard<'_, T> {} -unsafe impl Sync for MutexGuard<'_, T> {} +impl !Send for MutexGuard<'_, T> {} +unsafe impl Sync for MutexGuard<'_, T> {} impl Mutex { #[inline] @@ -34,14 +34,9 @@ impl Mutex { inner: Box::new(MutexInner::new()), } } - - #[inline] - pub fn into_inner(self) -> T { - self.value.into_inner() - } } -impl Mutex { +impl Mutex { #[inline] pub fn lock(&self) -> MutexGuard<'_, T> { self.inner.lock(); @@ -50,7 +45,7 @@ impl Mutex { #[inline] pub fn try_lock(&self) -> Option> { - self.inner.try_lock().map(|_| MutexGuard { inner: self }) + self.inner.try_lock().then(|| MutexGuard { inner: self }) } #[inline] @@ -67,9 +62,17 @@ impl Mutex { pub fn get_mut(&mut self) -> &mut T { self.value.get_mut() } + + #[inline] + pub fn into_inner(self) -> T + where + T: Sized, + { + self.value.into_inner() + } } -impl Deref for MutexGuard<'_, T> { +impl Deref for MutexGuard<'_, T> { type Target = T; fn deref(&self) -> &Self::Target { @@ -77,13 +80,13 @@ impl Deref for MutexGuard<'_, T> { } } -impl DerefMut for MutexGuard<'_, T> { +impl DerefMut for MutexGuard<'_, T> { fn deref_mut(&mut self) -> &mut Self::Target { unsafe { &mut *self.inner.value.get() } } } -impl Drop for MutexGuard<'_, T> { +impl Drop for MutexGuard<'_, T> { fn drop(&mut self) { unsafe { self.inner.force_unlock(); @@ -91,13 +94,13 @@ impl Drop for MutexGuard<'_, T> { } } -impl fmt::Debug for MutexGuard<'_, T> { +impl fmt::Debug for MutexGuard<'_, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fmt::Debug::fmt(&**self, f) } } -impl fmt::Debug for Mutex { +impl fmt::Debug for Mutex { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self.try_lock() { Some(guard) => write!(f, "Mutex {{ value: ") @@ -138,10 +141,10 @@ impl MutexInner { } #[inline] - pub fn try_lock(&self) -> Option { + pub fn try_lock(&self) -> bool { self.lock .compare_exchange(0, 1, Ordering::Acquire, Ordering::Relaxed) - .ok() + .is_ok() } #[inline]