From cd2f13ae5412e6c497b0456ed9f8307bfd793d41 Mon Sep 17 00:00:00 2001 From: He Sun Date: Thu, 14 May 2020 19:32:34 +0800 Subject: [PATCH] Refactor select syscall 1. Substitute the underlying poll OCall to select OCall to update the timeout argument correctly. 2. Add more checks for the inputs. --- src/Enclave.edl | 8 + src/libos/include/edl/occlum_edl_types.h | 5 + src/libos/src/net/io_multiplexing/mod.rs | 3 +- src/libos/src/net/io_multiplexing/select.rs | 237 +++++++++++++++----- src/libos/src/net/syscalls.rs | 45 ++-- src/libos/src/syscall/mod.rs | 2 +- src/pal/include/edl/occlum_edl_types.h | 1 + src/pal/src/ocalls/net.c | 9 + test/eventfd/main.c | 16 +- 9 files changed, 252 insertions(+), 74 deletions(-) diff --git a/src/Enclave.edl b/src/Enclave.edl index 0bcba69f..c4ea828a 100644 --- a/src/Enclave.edl +++ b/src/Enclave.edl @@ -161,6 +161,14 @@ enclave { int flags ) propagate_errno; + int occlum_ocall_select( + int nfds, + [in, out] fd_set *readfds, + [in, out] fd_set *writefds, + [in, out] fd_set *exceptfds, + [in, out] struct timeval *timeout + ) propagate_errno; + void occlum_ocall_print_log(uint32_t level, [in, string] const char* msg); void occlum_ocall_flush_log(void); diff --git a/src/libos/include/edl/occlum_edl_types.h b/src/libos/include/edl/occlum_edl_types.h index 98160112..1759a640 100644 --- a/src/libos/include/edl/occlum_edl_types.h +++ b/src/libos/include/edl/occlum_edl_types.h @@ -16,4 +16,9 @@ struct occlum_stdio_fds { int stderr_fd; }; +#define FD_SETSIZE 1024 +typedef struct { + unsigned long fds_bits[FD_SETSIZE / 8 / sizeof(long)]; +} fd_set; + #endif /* __OCCLUM_EDL_TYPES_H__ */ diff --git a/src/libos/src/net/io_multiplexing/mod.rs b/src/libos/src/net/io_multiplexing/mod.rs index d440577c..df7e7fcd 100644 --- a/src/libos/src/net/io_multiplexing/mod.rs +++ b/src/libos/src/net/io_multiplexing/mod.rs @@ -6,10 +6,11 @@ mod select; pub use self::epoll::{AsEpollFile, EpollCtlCmd, EpollEvent, EpollEventFlags, EpollFile}; pub use self::poll::do_poll; -pub use self::select::do_select; +pub use self::select::{select, FdSetExt}; use fs::{AsDevRandom, AsEvent, CreationFlags, File, FileDesc, FileRef}; use std::any::Any; use std::convert::TryFrom; use std::fmt; use std::sync::atomic::spin_loop_hint; +use time::timeval_t; diff --git a/src/libos/src/net/io_multiplexing/select.rs b/src/libos/src/net/io_multiplexing/select.rs index 1ecc4236..c7920681 100644 --- a/src/libos/src/net/io_multiplexing/select.rs +++ b/src/libos/src/net/io_multiplexing/select.rs @@ -1,24 +1,24 @@ use super::*; -/// Forward to host `poll` -/// (sgx_libc doesn't have `select`) -pub fn do_select( - nfds: usize, +pub fn select( + nfds: c_int, readfds: &mut libc::fd_set, writefds: &mut libc::fd_set, exceptfds: &mut libc::fd_set, - timeout: Option, -) -> Result { - debug!("select: nfds: {}", nfds); - // convert libos fd to Linux fd - let mut host_to_libos_fd = [0; libc::FD_SETSIZE]; - let mut polls = Vec::::new(); + timeout: Option<&mut timeval_t>, +) -> Result { + debug!("select: nfds: {} timeout: {:?}", nfds, timeout); let current = current!(); let file_table = current.files().lock().unwrap(); - for fd in 0..nfds { - let fd_ref = file_table.get(fd as FileDesc)?; + let mut max_host_fd = None; + let mut host_to_libos_fd = [None; libc::FD_SETSIZE]; + let mut unsafe_readfds = libc::fd_set::new_empty(); + let mut unsafe_writefds = libc::fd_set::new_empty(); + let mut unsafe_exceptfds = libc::fd_set::new_empty(); + + for fd in 0..(nfds as FileDesc) { let (r, w, e) = ( readfds.is_set(fd), writefds.is_set(fd), @@ -27,6 +27,9 @@ pub fn do_select( if !(r || w || e) { continue; } + + let fd_ref = file_table.get(fd)?; + if let Ok(socket) = fd_ref.as_unix_socket() { warn!("select unix socket is unimplemented, spin for read"); readfds.clear(); @@ -39,89 +42,187 @@ pub fn do_select( } let (rr, ww, ee) = socket.poll()?; + let mut ready_num = 0; if r && rr { - readfds.set(fd); + readfds.set(fd)?; + ready_num += 1; } if w && ww { - writefds.set(fd); + writefds.set(fd)?; + ready_num += 1; } if e && ee { - writefds.set(fd); + exceptfds.set(fd)?; + ready_num += 1; } - return Ok(1); + return Ok(ready_num); } + let host_fd = if let Ok(socket) = fd_ref.as_socket() { socket.fd() } else if let Ok(eventfd) = fd_ref.as_event() { eventfd.get_host_fd() } else { return_errno!(EBADF, "unsupported file type"); - }; + } as FileDesc; - host_to_libos_fd[host_fd as usize] = fd; - let mut events = 0; + if host_fd as usize >= libc::FD_SETSIZE { + return_errno!(EBADF, "host fd exceeds FD_SETSIZE"); + } + + // convert libos fd to host fd + host_to_libos_fd[host_fd as usize] = Some(fd); + max_host_fd = Some(max(max_host_fd.unwrap_or(0), host_fd as c_int)); if r { - events |= libc::POLLIN; + unsafe_readfds.set(host_fd)?; } if w { - events |= libc::POLLOUT; + unsafe_writefds.set(host_fd)?; } if e { - events |= libc::POLLERR; + unsafe_exceptfds.set(host_fd)?; } - - polls.push(libc::pollfd { - fd: host_fd as c_int, - events, - revents: 0, - }); } // Unlock the file table as early as possible drop(file_table); - let timeout = match timeout { - None => -1, - Some(tv) => (tv.tv_sec * 1000 + tv.tv_usec / 1000) as i32, + let host_nfds = if let Some(fd) = max_host_fd { + fd + 1 + } else { + // Set nfds to zero if no fd is monitored + 0 }; - let (polls_ptr, polls_len) = polls.as_mut_slice().as_mut_ptr_and_len(); - let ret = try_libc!(libc::ocall::poll(polls_ptr, polls_len as u64, timeout)); + let ret = do_select_in_host( + host_nfds, + &mut unsafe_readfds, + &mut unsafe_writefds, + &mut unsafe_exceptfds, + timeout, + )?; - // convert fd back and write fdset - readfds.clear(); - writefds.clear(); - exceptfds.clear(); + // convert fd back and write fdset and do ocall check + let mut ready_num = 0; + for host_fd in 0..host_nfds as FileDesc { + let fd_option = host_to_libos_fd[host_fd as usize]; + let (r, w, e) = ( + unsafe_readfds.is_set(host_fd), + unsafe_writefds.is_set(host_fd), + unsafe_exceptfds.is_set(host_fd), + ); + if !(r || w || e) { + if let Some(fd) = fd_option { + readfds.unset(fd)?; + writefds.unset(fd)?; + exceptfds.unset(fd)?; + } + continue; + } - for poll in polls.iter() { - let fd = host_to_libos_fd[poll.fd as usize]; - if poll.revents & libc::POLLIN != 0 { - readfds.set(fd); + let fd = fd_option.expect("host_fd with events must have a responding libos fd"); + + if r { + assert!(readfds.is_set(fd)); + ready_num += 1; + } else { + readfds.unset(fd)?; } - if poll.revents & libc::POLLOUT != 0 { - writefds.set(fd); + if w { + assert!(writefds.is_set(fd)); + ready_num += 1; + } else { + writefds.unset(fd)?; } - if poll.revents & libc::POLLERR != 0 { - exceptfds.set(fd); + if e { + assert!(exceptfds.is_set(fd)); + ready_num += 1; + } else { + exceptfds.unset(fd)?; } } - Ok(ret as usize) + assert!(ready_num == ret); + Ok(ret) +} + +fn do_select_in_host( + host_nfds: c_int, + readfds: &mut libc::fd_set, + writefds: &mut libc::fd_set, + exceptfds: &mut libc::fd_set, + timeout: Option<&mut timeval_t>, +) -> Result { + let readfds_ptr = readfds.as_raw_ptr_mut(); + let writefds_ptr = writefds.as_raw_ptr_mut(); + let exceptfds_ptr = exceptfds.as_raw_ptr_mut(); + + let mut origin_timeout: timeval_t = Default::default(); + let timeout_ptr = if let Some(to) = timeout { + origin_timeout = *to; + to + } else { + std::ptr::null_mut() + } as *mut timeval_t; + + let ret = try_libc!({ + let mut retval: c_int = 0; + let status = occlum_ocall_select( + &mut retval, + host_nfds, + readfds_ptr, + writefds_ptr, + exceptfds_ptr, + timeout_ptr, + ); + assert!(status == sgx_status_t::SGX_SUCCESS); + + retval + }) as isize; + + if !timeout_ptr.is_null() { + let time_left = unsafe { *(timeout_ptr) }; + time_left.validate()?; + assert!(time_left.as_duration() <= origin_timeout.as_duration()); + } + + Ok(ret) } /// Safe methods for `libc::fd_set` -trait FdSetExt { - fn set(&mut self, fd: usize); +pub trait FdSetExt { + fn new_empty() -> Self; + fn unset(&mut self, fd: FileDesc) -> Result<()>; + fn is_set(&self, fd: FileDesc) -> bool; + fn set(&mut self, fd: FileDesc) -> Result<()>; fn clear(&mut self); - fn is_set(&mut self, fd: usize) -> bool; + fn is_empty(&self) -> bool; + fn as_raw_ptr_mut(&mut self) -> *mut Self; } impl FdSetExt for libc::fd_set { - fn set(&mut self, fd: usize) { - assert!(fd < libc::FD_SETSIZE); + fn new_empty() -> Self { + unsafe { core::mem::zeroed() } + } + + fn unset(&mut self, fd: FileDesc) -> Result<()> { + if fd as usize >= libc::FD_SETSIZE { + return_errno!(EINVAL, "fd exceeds FD_SETSIZE"); + } + unsafe { + libc::FD_CLR(fd as c_int, self); + } + Ok(()) + } + + fn set(&mut self, fd: FileDesc) -> Result<()> { + if fd as usize >= libc::FD_SETSIZE { + return_errno!(EINVAL, "fd exceeds FD_SETSIZE"); + } unsafe { libc::FD_SET(fd as c_int, self); } + Ok(()) } fn clear(&mut self) { @@ -130,8 +231,36 @@ impl FdSetExt for libc::fd_set { } } - fn is_set(&mut self, fd: usize) -> bool { - assert!(fd < libc::FD_SETSIZE); - unsafe { libc::FD_ISSET(fd as c_int, self) } + fn is_set(&self, fd: FileDesc) -> bool { + if fd as usize >= libc::FD_SETSIZE { + return false; + } + unsafe { libc::FD_ISSET(fd as c_int, self as *const Self as *mut Self) } + } + + fn is_empty(&self) -> bool { + let set = unsafe { + std::slice::from_raw_parts(self as *const Self as *const u64, libc::FD_SETSIZE / 64) + }; + set.iter().all(|&x| x == 0) + } + + fn as_raw_ptr_mut(&mut self) -> *mut Self { + if self.is_empty() { + std::ptr::null_mut() + } else { + self as *mut libc::fd_set + } } } + +extern "C" { + fn occlum_ocall_select( + ret: *mut c_int, + nfds: c_int, + readfds: *mut libc::fd_set, + writefds: *mut libc::fd_set, + exceptfds: *mut libc::fd_set, + timeout: *mut timeval_t, + ) -> sgx_status_t; +} diff --git a/src/libos/src/net/syscalls.rs b/src/libos/src/net/syscalls.rs index 67ed5a09..364e31a5 100644 --- a/src/libos/src/net/syscalls.rs +++ b/src/libos/src/net/syscalls.rs @@ -1,10 +1,11 @@ use super::*; -use super::io_multiplexing::{AsEpollFile, EpollCtlCmd, EpollEventFlags, EpollFile}; +use super::io_multiplexing::{AsEpollFile, EpollCtlCmd, EpollEventFlags, EpollFile, FdSetExt}; use fs::{CreationFlags, File, FileDesc, FileRef}; use misc::resource_t; use process::Process; use std::convert::TryFrom; +use time::timeval_t; use util::mem_util::from_user; pub fn do_sendmsg(fd: c_int, msg_ptr: *const msghdr, flags_c: c_int) -> Result { @@ -129,45 +130,59 @@ pub fn do_select( readfds: *mut libc::fd_set, writefds: *mut libc::fd_set, exceptfds: *mut libc::fd_set, - timeout: *const libc::timeval, + timeout: *mut timeval_t, ) -> Result { // check arguments - if nfds < 0 || nfds >= libc::FD_SETSIZE as c_int { - return_errno!(EINVAL, "nfds is negative or exceeds the resource limit"); + let soft_rlimit_nofile = current!() + .rlimits() + .lock() + .unwrap() + .get(resource_t::RLIMIT_NOFILE) + .get_cur(); + if nfds < 0 || nfds > libc::FD_SETSIZE as i32 || nfds as u64 > soft_rlimit_nofile { + return_errno!( + EINVAL, + "nfds is negative or exceeds the resource limit or FD_SETSIZE" + ); } - let nfds = nfds as usize; - let mut zero_fds0: libc::fd_set = unsafe { core::mem::zeroed() }; - let mut zero_fds1: libc::fd_set = unsafe { core::mem::zeroed() }; - let mut zero_fds2: libc::fd_set = unsafe { core::mem::zeroed() }; + // Select handles empty set and null in the same way + // TODO: Elegently handle the empty fd_set without allocating redundant fd_set + let mut empty_set_for_read = libc::fd_set::new_empty(); + let mut empty_set_for_write = libc::fd_set::new_empty(); + let mut empty_set_for_except = libc::fd_set::new_empty(); let readfds = if !readfds.is_null() { from_user::check_mut_ptr(readfds)?; unsafe { &mut *readfds } } else { - &mut zero_fds0 + &mut empty_set_for_read }; let writefds = if !writefds.is_null() { from_user::check_mut_ptr(writefds)?; unsafe { &mut *writefds } } else { - &mut zero_fds1 + &mut empty_set_for_write }; let exceptfds = if !exceptfds.is_null() { from_user::check_mut_ptr(exceptfds)?; unsafe { &mut *exceptfds } } else { - &mut zero_fds2 + &mut empty_set_for_except }; - let timeout = if !timeout.is_null() { + + let timeout_option = if !timeout.is_null() { from_user::check_ptr(timeout)?; - Some(unsafe { timeout.read() }) + unsafe { + (*timeout).validate()?; + Some(&mut *timeout) + } } else { None }; - let n = io_multiplexing::do_select(nfds, readfds, writefds, exceptfds, timeout)?; - Ok(n as isize) + let ret = io_multiplexing::select(nfds, readfds, writefds, exceptfds, timeout_option)?; + Ok(ret) } pub fn do_poll(fds: *mut libc::pollfd, nfds: libc::nfds_t, timeout: c_int) -> Result { diff --git a/src/libos/src/syscall/mod.rs b/src/libos/src/syscall/mod.rs index 624dabf9..b56ed062 100644 --- a/src/libos/src/syscall/mod.rs +++ b/src/libos/src/syscall/mod.rs @@ -98,7 +98,7 @@ macro_rules! process_syscall_table_with_callback { (Writev = 20) => do_writev(fd: FileDesc, iov: *const iovec_t, count: i32), (Access = 21) => do_access(path: *const i8, mode: u32), (Pipe = 22) => do_pipe(fds_u: *mut i32), - (Select = 23) => do_select(nfds: c_int, readfds: *mut libc::fd_set, writefds: *mut libc::fd_set, exceptfds: *mut libc::fd_set, timeout: *const libc::timeval), + (Select = 23) => do_select(nfds: c_int, readfds: *mut libc::fd_set, writefds: *mut libc::fd_set, exceptfds: *mut libc::fd_set, timeout: *mut timeval_t), (SchedYield = 24) => do_sched_yield(), (Mremap = 25) => do_mremap(old_addr: usize, old_size: usize, new_size: usize, flags: i32, new_addr: usize), (Msync = 26) => handle_unsupported(), diff --git a/src/pal/include/edl/occlum_edl_types.h b/src/pal/include/edl/occlum_edl_types.h index 06c794a8..8994d714 100644 --- a/src/pal/include/edl/occlum_edl_types.h +++ b/src/pal/include/edl/occlum_edl_types.h @@ -2,6 +2,7 @@ #define __OCCLUM_EDL_TYPES__ #include // import struct timespec +#include // import fd_set #include // import struct timeval #include // import struct iovec #include // import occlum_stdio_fds diff --git a/src/pal/src/ocalls/net.c b/src/pal/src/ocalls/net.c index 704889df..bc699061 100644 --- a/src/pal/src/ocalls/net.c +++ b/src/pal/src/ocalls/net.c @@ -1,4 +1,5 @@ #include +#include #include #include #include "ocalls.h" @@ -53,3 +54,11 @@ ssize_t occlum_ocall_recvmsg(int sockfd, *msg_flags_recv = msg.msg_flags; return ret; } + +int occlum_ocall_select(int nfds, + fd_set *readfds, + fd_set *writefds, + fd_set *exceptfds, + struct timeval *timeout) { + return select(nfds, readfds, writefds, exceptfds, timeout); +} diff --git a/test/eventfd/main.c b/test/eventfd/main.c index 1d04fcc0..9af406b5 100644 --- a/test/eventfd/main.c +++ b/test/eventfd/main.c @@ -145,7 +145,8 @@ int test_read_write() { } int test_select_with_socket() { - fd_set wfds; + fd_set rfds, wfds; + int ret = 0; struct timeval tv = { .tv_sec = 60, .tv_usec = 0 }; @@ -155,15 +156,24 @@ int test_select_with_socket() { THROW_ERROR("failed to create files"); } + FD_ZERO(&rfds); FD_ZERO(&wfds); + FD_SET(sock, &rfds); FD_SET(sock, &wfds); + FD_SET(event_fd, &rfds); FD_SET(event_fd, &wfds); - - if (select(sock > event_fd? sock + 1: event_fd + 1, NULL, &wfds, NULL, &tv) <= 0) { + ret = select(sock > event_fd? sock + 1: event_fd + 1, &rfds, &wfds, NULL, &tv); + if (ret != 3) { close_files(2, sock, event_fd); THROW_ERROR("select failed"); } + if (FD_ISSET(event_fd, &rfds) == 1 || FD_ISSET(event_fd, &wfds) == 0 || + FD_ISSET(sock, &rfds) == 0 || FD_ISSET(sock, &wfds) == 0) { + close_files(2, sock, event_fd); + THROW_ERROR("bad select return"); + } + close_files(2, sock, event_fd); return 0; }