Refactor select syscall

1. Replace the underlying poll OCall with a select OCall so that the
timeout argument is updated correctly.
2. Add more checks for the inputs.
This commit is contained in:
He Sun 2020-05-14 19:32:34 +08:00
parent c14ee62678
commit cd2f13ae54
9 changed files with 252 additions and 74 deletions

@ -161,6 +161,14 @@ enclave {
int flags int flags
) propagate_errno; ) propagate_errno;
// OCall used to run `select` on the host. The three fd sets and the
// timeout are all declared [in, out], so they are passed to the host and
// whatever the host writes into them is handed back to the caller.
int occlum_ocall_select(
int nfds,
[in, out] fd_set *readfds,
[in, out] fd_set *writefds,
[in, out] fd_set *exceptfds,
[in, out] struct timeval *timeout
) propagate_errno;
void occlum_ocall_print_log(uint32_t level, [in, string] const char* msg); void occlum_ocall_print_log(uint32_t level, [in, string] const char* msg);
void occlum_ocall_flush_log(void); void occlum_ocall_flush_log(void);

@ -16,4 +16,9 @@ struct occlum_stdio_fds {
int stderr_fd; int stderr_fd;
}; };
/* Number of bits (one per file descriptor) held by an fd_set. */
#define FD_SETSIZE 1024
/*
 * Fixed-size bitmap: FD_SETSIZE bits are FD_SETSIZE / 8 bytes, stored as
 * an array of unsigned long words.
 */
typedef struct {
unsigned long fds_bits[FD_SETSIZE / 8 / sizeof(long)];
} fd_set;
#endif /* __OCCLUM_EDL_TYPES_H__ */ #endif /* __OCCLUM_EDL_TYPES_H__ */

@ -6,10 +6,11 @@ mod select;
pub use self::epoll::{AsEpollFile, EpollCtlCmd, EpollEvent, EpollEventFlags, EpollFile}; pub use self::epoll::{AsEpollFile, EpollCtlCmd, EpollEvent, EpollEventFlags, EpollFile};
pub use self::poll::do_poll; pub use self::poll::do_poll;
pub use self::select::do_select; pub use self::select::{select, FdSetExt};
use fs::{AsDevRandom, AsEvent, CreationFlags, File, FileDesc, FileRef}; use fs::{AsDevRandom, AsEvent, CreationFlags, File, FileDesc, FileRef};
use std::any::Any; use std::any::Any;
use std::convert::TryFrom; use std::convert::TryFrom;
use std::fmt; use std::fmt;
use std::sync::atomic::spin_loop_hint; use std::sync::atomic::spin_loop_hint;
use time::timeval_t;

@ -1,24 +1,24 @@
use super::*; use super::*;
/// Forward to host `poll` pub fn select(
/// (sgx_libc doesn't have `select`) nfds: c_int,
pub fn do_select(
nfds: usize,
readfds: &mut libc::fd_set, readfds: &mut libc::fd_set,
writefds: &mut libc::fd_set, writefds: &mut libc::fd_set,
exceptfds: &mut libc::fd_set, exceptfds: &mut libc::fd_set,
timeout: Option<libc::timeval>, timeout: Option<&mut timeval_t>,
) -> Result<usize> { ) -> Result<isize> {
debug!("select: nfds: {}", nfds); debug!("select: nfds: {} timeout: {:?}", nfds, timeout);
// convert libos fd to Linux fd
let mut host_to_libos_fd = [0; libc::FD_SETSIZE];
let mut polls = Vec::<libc::pollfd>::new();
let current = current!(); let current = current!();
let file_table = current.files().lock().unwrap(); let file_table = current.files().lock().unwrap();
for fd in 0..nfds { let mut max_host_fd = None;
let fd_ref = file_table.get(fd as FileDesc)?; let mut host_to_libos_fd = [None; libc::FD_SETSIZE];
let mut unsafe_readfds = libc::fd_set::new_empty();
let mut unsafe_writefds = libc::fd_set::new_empty();
let mut unsafe_exceptfds = libc::fd_set::new_empty();
for fd in 0..(nfds as FileDesc) {
let (r, w, e) = ( let (r, w, e) = (
readfds.is_set(fd), readfds.is_set(fd),
writefds.is_set(fd), writefds.is_set(fd),
@ -27,6 +27,9 @@ pub fn do_select(
if !(r || w || e) { if !(r || w || e) {
continue; continue;
} }
let fd_ref = file_table.get(fd)?;
if let Ok(socket) = fd_ref.as_unix_socket() { if let Ok(socket) = fd_ref.as_unix_socket() {
warn!("select unix socket is unimplemented, spin for read"); warn!("select unix socket is unimplemented, spin for read");
readfds.clear(); readfds.clear();
@ -39,89 +42,187 @@ pub fn do_select(
} }
let (rr, ww, ee) = socket.poll()?; let (rr, ww, ee) = socket.poll()?;
let mut ready_num = 0;
if r && rr { if r && rr {
readfds.set(fd); readfds.set(fd)?;
ready_num += 1;
} }
if w && ww { if w && ww {
writefds.set(fd); writefds.set(fd)?;
ready_num += 1;
} }
if e && ee { if e && ee {
writefds.set(fd); exceptfds.set(fd)?;
ready_num += 1;
} }
return Ok(1); return Ok(ready_num);
} }
let host_fd = if let Ok(socket) = fd_ref.as_socket() { let host_fd = if let Ok(socket) = fd_ref.as_socket() {
socket.fd() socket.fd()
} else if let Ok(eventfd) = fd_ref.as_event() { } else if let Ok(eventfd) = fd_ref.as_event() {
eventfd.get_host_fd() eventfd.get_host_fd()
} else { } else {
return_errno!(EBADF, "unsupported file type"); return_errno!(EBADF, "unsupported file type");
}; } as FileDesc;
host_to_libos_fd[host_fd as usize] = fd; if host_fd as usize >= libc::FD_SETSIZE {
let mut events = 0; return_errno!(EBADF, "host fd exceeds FD_SETSIZE");
}
// convert libos fd to host fd
host_to_libos_fd[host_fd as usize] = Some(fd);
max_host_fd = Some(max(max_host_fd.unwrap_or(0), host_fd as c_int));
if r { if r {
events |= libc::POLLIN; unsafe_readfds.set(host_fd)?;
} }
if w { if w {
events |= libc::POLLOUT; unsafe_writefds.set(host_fd)?;
} }
if e { if e {
events |= libc::POLLERR; unsafe_exceptfds.set(host_fd)?;
} }
polls.push(libc::pollfd {
fd: host_fd as c_int,
events,
revents: 0,
});
} }
// Unlock the file table as early as possible // Unlock the file table as early as possible
drop(file_table); drop(file_table);
let timeout = match timeout { let host_nfds = if let Some(fd) = max_host_fd {
None => -1, fd + 1
Some(tv) => (tv.tv_sec * 1000 + tv.tv_usec / 1000) as i32, } else {
// Set nfds to zero if no fd is monitored
0
}; };
let (polls_ptr, polls_len) = polls.as_mut_slice().as_mut_ptr_and_len(); let ret = do_select_in_host(
let ret = try_libc!(libc::ocall::poll(polls_ptr, polls_len as u64, timeout)); host_nfds,
&mut unsafe_readfds,
&mut unsafe_writefds,
&mut unsafe_exceptfds,
timeout,
)?;
// convert fd back and write fdset // convert fd back and write fdset and do ocall check
readfds.clear(); let mut ready_num = 0;
writefds.clear(); for host_fd in 0..host_nfds as FileDesc {
exceptfds.clear(); let fd_option = host_to_libos_fd[host_fd as usize];
let (r, w, e) = (
unsafe_readfds.is_set(host_fd),
unsafe_writefds.is_set(host_fd),
unsafe_exceptfds.is_set(host_fd),
);
if !(r || w || e) {
if let Some(fd) = fd_option {
readfds.unset(fd)?;
writefds.unset(fd)?;
exceptfds.unset(fd)?;
}
continue;
}
for poll in polls.iter() { let fd = fd_option.expect("host_fd with events must have a responding libos fd");
let fd = host_to_libos_fd[poll.fd as usize];
if poll.revents & libc::POLLIN != 0 { if r {
readfds.set(fd); assert!(readfds.is_set(fd));
ready_num += 1;
} else {
readfds.unset(fd)?;
} }
if poll.revents & libc::POLLOUT != 0 { if w {
writefds.set(fd); assert!(writefds.is_set(fd));
ready_num += 1;
} else {
writefds.unset(fd)?;
} }
if poll.revents & libc::POLLERR != 0 { if e {
exceptfds.set(fd); assert!(exceptfds.is_set(fd));
ready_num += 1;
} else {
exceptfds.unset(fd)?;
} }
} }
Ok(ret as usize) assert!(ready_num == ret);
Ok(ret)
}
fn do_select_in_host(
host_nfds: c_int,
readfds: &mut libc::fd_set,
writefds: &mut libc::fd_set,
exceptfds: &mut libc::fd_set,
timeout: Option<&mut timeval_t>,
) -> Result<isize> {
let readfds_ptr = readfds.as_raw_ptr_mut();
let writefds_ptr = writefds.as_raw_ptr_mut();
let exceptfds_ptr = exceptfds.as_raw_ptr_mut();
let mut origin_timeout: timeval_t = Default::default();
let timeout_ptr = if let Some(to) = timeout {
origin_timeout = *to;
to
} else {
std::ptr::null_mut()
} as *mut timeval_t;
let ret = try_libc!({
let mut retval: c_int = 0;
let status = occlum_ocall_select(
&mut retval,
host_nfds,
readfds_ptr,
writefds_ptr,
exceptfds_ptr,
timeout_ptr,
);
assert!(status == sgx_status_t::SGX_SUCCESS);
retval
}) as isize;
if !timeout_ptr.is_null() {
let time_left = unsafe { *(timeout_ptr) };
time_left.validate()?;
assert!(time_left.as_duration() <= origin_timeout.as_duration());
}
Ok(ret)
} }
/// Safe methods for `libc::fd_set` /// Safe methods for `libc::fd_set`
trait FdSetExt { pub trait FdSetExt {
fn set(&mut self, fd: usize); fn new_empty() -> Self;
fn unset(&mut self, fd: FileDesc) -> Result<()>;
fn is_set(&self, fd: FileDesc) -> bool;
fn set(&mut self, fd: FileDesc) -> Result<()>;
fn clear(&mut self); fn clear(&mut self);
fn is_set(&mut self, fd: usize) -> bool; fn is_empty(&self) -> bool;
fn as_raw_ptr_mut(&mut self) -> *mut Self;
} }
impl FdSetExt for libc::fd_set { impl FdSetExt for libc::fd_set {
fn set(&mut self, fd: usize) { fn new_empty() -> Self {
assert!(fd < libc::FD_SETSIZE); unsafe { core::mem::zeroed() }
}
fn unset(&mut self, fd: FileDesc) -> Result<()> {
if fd as usize >= libc::FD_SETSIZE {
return_errno!(EINVAL, "fd exceeds FD_SETSIZE");
}
unsafe {
libc::FD_CLR(fd as c_int, self);
}
Ok(())
}
fn set(&mut self, fd: FileDesc) -> Result<()> {
if fd as usize >= libc::FD_SETSIZE {
return_errno!(EINVAL, "fd exceeds FD_SETSIZE");
}
unsafe { unsafe {
libc::FD_SET(fd as c_int, self); libc::FD_SET(fd as c_int, self);
} }
Ok(())
} }
fn clear(&mut self) { fn clear(&mut self) {
@ -130,8 +231,36 @@ impl FdSetExt for libc::fd_set {
} }
} }
fn is_set(&mut self, fd: usize) -> bool { fn is_set(&self, fd: FileDesc) -> bool {
assert!(fd < libc::FD_SETSIZE); if fd as usize >= libc::FD_SETSIZE {
unsafe { libc::FD_ISSET(fd as c_int, self) } return false;
}
unsafe { libc::FD_ISSET(fd as c_int, self as *const Self as *mut Self) }
}
fn is_empty(&self) -> bool {
let set = unsafe {
std::slice::from_raw_parts(self as *const Self as *const u64, libc::FD_SETSIZE / 64)
};
set.iter().all(|&x| x == 0)
}
fn as_raw_ptr_mut(&mut self) -> *mut Self {
if self.is_empty() {
std::ptr::null_mut()
} else {
self as *mut libc::fd_set
} }
} }
}
// Foreign declaration of the `occlum_ocall_select` OCall.
//
// `ret` receives the OCall's own return value, while the returned
// `sgx_status_t` reports the status of the call itself. The remaining
// parameters are raw pointers to the fd sets and the timeout handed to the
// OCall; the caller in this file passes null for an empty fd set or an
// absent timeout.
extern "C" {
fn occlum_ocall_select(
ret: *mut c_int,
nfds: c_int,
readfds: *mut libc::fd_set,
writefds: *mut libc::fd_set,
exceptfds: *mut libc::fd_set,
timeout: *mut timeval_t,
) -> sgx_status_t;
}

@ -1,10 +1,11 @@
use super::*; use super::*;
use super::io_multiplexing::{AsEpollFile, EpollCtlCmd, EpollEventFlags, EpollFile}; use super::io_multiplexing::{AsEpollFile, EpollCtlCmd, EpollEventFlags, EpollFile, FdSetExt};
use fs::{CreationFlags, File, FileDesc, FileRef}; use fs::{CreationFlags, File, FileDesc, FileRef};
use misc::resource_t; use misc::resource_t;
use process::Process; use process::Process;
use std::convert::TryFrom; use std::convert::TryFrom;
use time::timeval_t;
use util::mem_util::from_user; use util::mem_util::from_user;
pub fn do_sendmsg(fd: c_int, msg_ptr: *const msghdr, flags_c: c_int) -> Result<isize> { pub fn do_sendmsg(fd: c_int, msg_ptr: *const msghdr, flags_c: c_int) -> Result<isize> {
@ -129,45 +130,59 @@ pub fn do_select(
readfds: *mut libc::fd_set, readfds: *mut libc::fd_set,
writefds: *mut libc::fd_set, writefds: *mut libc::fd_set,
exceptfds: *mut libc::fd_set, exceptfds: *mut libc::fd_set,
timeout: *const libc::timeval, timeout: *mut timeval_t,
) -> Result<isize> { ) -> Result<isize> {
// check arguments // check arguments
if nfds < 0 || nfds >= libc::FD_SETSIZE as c_int { let soft_rlimit_nofile = current!()
return_errno!(EINVAL, "nfds is negative or exceeds the resource limit"); .rlimits()
.lock()
.unwrap()
.get(resource_t::RLIMIT_NOFILE)
.get_cur();
if nfds < 0 || nfds > libc::FD_SETSIZE as i32 || nfds as u64 > soft_rlimit_nofile {
return_errno!(
EINVAL,
"nfds is negative or exceeds the resource limit or FD_SETSIZE"
);
} }
let nfds = nfds as usize;
let mut zero_fds0: libc::fd_set = unsafe { core::mem::zeroed() }; // Select handles empty set and null in the same way
let mut zero_fds1: libc::fd_set = unsafe { core::mem::zeroed() }; // TODO: Elegantly handle the empty fd_set without allocating redundant fd_set
let mut zero_fds2: libc::fd_set = unsafe { core::mem::zeroed() }; let mut empty_set_for_read = libc::fd_set::new_empty();
let mut empty_set_for_write = libc::fd_set::new_empty();
let mut empty_set_for_except = libc::fd_set::new_empty();
let readfds = if !readfds.is_null() { let readfds = if !readfds.is_null() {
from_user::check_mut_ptr(readfds)?; from_user::check_mut_ptr(readfds)?;
unsafe { &mut *readfds } unsafe { &mut *readfds }
} else { } else {
&mut zero_fds0 &mut empty_set_for_read
}; };
let writefds = if !writefds.is_null() { let writefds = if !writefds.is_null() {
from_user::check_mut_ptr(writefds)?; from_user::check_mut_ptr(writefds)?;
unsafe { &mut *writefds } unsafe { &mut *writefds }
} else { } else {
&mut zero_fds1 &mut empty_set_for_write
}; };
let exceptfds = if !exceptfds.is_null() { let exceptfds = if !exceptfds.is_null() {
from_user::check_mut_ptr(exceptfds)?; from_user::check_mut_ptr(exceptfds)?;
unsafe { &mut *exceptfds } unsafe { &mut *exceptfds }
} else { } else {
&mut zero_fds2 &mut empty_set_for_except
}; };
let timeout = if !timeout.is_null() {
let timeout_option = if !timeout.is_null() {
from_user::check_ptr(timeout)?; from_user::check_ptr(timeout)?;
Some(unsafe { timeout.read() }) unsafe {
(*timeout).validate()?;
Some(&mut *timeout)
}
} else { } else {
None None
}; };
let n = io_multiplexing::do_select(nfds, readfds, writefds, exceptfds, timeout)?; let ret = io_multiplexing::select(nfds, readfds, writefds, exceptfds, timeout_option)?;
Ok(n as isize) Ok(ret)
} }
pub fn do_poll(fds: *mut libc::pollfd, nfds: libc::nfds_t, timeout: c_int) -> Result<isize> { pub fn do_poll(fds: *mut libc::pollfd, nfds: libc::nfds_t, timeout: c_int) -> Result<isize> {

@ -98,7 +98,7 @@ macro_rules! process_syscall_table_with_callback {
(Writev = 20) => do_writev(fd: FileDesc, iov: *const iovec_t, count: i32), (Writev = 20) => do_writev(fd: FileDesc, iov: *const iovec_t, count: i32),
(Access = 21) => do_access(path: *const i8, mode: u32), (Access = 21) => do_access(path: *const i8, mode: u32),
(Pipe = 22) => do_pipe(fds_u: *mut i32), (Pipe = 22) => do_pipe(fds_u: *mut i32),
(Select = 23) => do_select(nfds: c_int, readfds: *mut libc::fd_set, writefds: *mut libc::fd_set, exceptfds: *mut libc::fd_set, timeout: *const libc::timeval), (Select = 23) => do_select(nfds: c_int, readfds: *mut libc::fd_set, writefds: *mut libc::fd_set, exceptfds: *mut libc::fd_set, timeout: *mut timeval_t),
(SchedYield = 24) => do_sched_yield(), (SchedYield = 24) => do_sched_yield(),
(Mremap = 25) => do_mremap(old_addr: usize, old_size: usize, new_size: usize, flags: i32, new_addr: usize), (Mremap = 25) => do_mremap(old_addr: usize, old_size: usize, new_size: usize, flags: i32, new_addr: usize),
(Msync = 26) => handle_unsupported(), (Msync = 26) => handle_unsupported(),

@ -2,6 +2,7 @@
#define __OCCLUM_EDL_TYPES__ #define __OCCLUM_EDL_TYPES__
#include <time.h> // import struct timespec #include <time.h> // import struct timespec
#include <sys/select.h> // import fd_set
#include <sys/time.h> // import struct timeval #include <sys/time.h> // import struct timeval
#include <sys/uio.h> // import struct iovec #include <sys/uio.h> // import struct iovec
#include <occlum_pal_api.h> // import occlum_stdio_fds #include <occlum_pal_api.h> // import occlum_stdio_fds

@ -1,4 +1,5 @@
#include <sys/types.h> #include <sys/types.h>
#include <sys/select.h>
#include <sys/socket.h> #include <sys/socket.h>
#include <stddef.h> #include <stddef.h>
#include "ocalls.h" #include "ocalls.h"
@ -53,3 +54,11 @@ ssize_t occlum_ocall_recvmsg(int sockfd,
*msg_flags_recv = msg.msg_flags; *msg_flags_recv = msg.msg_flags;
return ret; return ret;
} }
// Thin wrapper for the select OCall: passes every argument through to
// select() unchanged and returns select()'s result as-is.
int occlum_ocall_select(int nfds,
fd_set *readfds,
fd_set *writefds,
fd_set *exceptfds,
struct timeval *timeout) {
return select(nfds, readfds, writefds, exceptfds, timeout);
}

@ -145,7 +145,8 @@ int test_read_write() {
} }
int test_select_with_socket() { int test_select_with_socket() {
fd_set wfds; fd_set rfds, wfds;
int ret = 0;
struct timeval tv = { .tv_sec = 60, .tv_usec = 0 }; struct timeval tv = { .tv_sec = 60, .tv_usec = 0 };
@ -155,15 +156,24 @@ int test_select_with_socket() {
THROW_ERROR("failed to create files"); THROW_ERROR("failed to create files");
} }
FD_ZERO(&rfds);
FD_ZERO(&wfds); FD_ZERO(&wfds);
FD_SET(sock, &rfds);
FD_SET(sock, &wfds); FD_SET(sock, &wfds);
FD_SET(event_fd, &rfds);
FD_SET(event_fd, &wfds); FD_SET(event_fd, &wfds);
ret = select(sock > event_fd? sock + 1: event_fd + 1, &rfds, &wfds, NULL, &tv);
if (select(sock > event_fd? sock + 1: event_fd + 1, NULL, &wfds, NULL, &tv) <= 0) { if (ret != 3) {
close_files(2, sock, event_fd); close_files(2, sock, event_fd);
THROW_ERROR("select failed"); THROW_ERROR("select failed");
} }
if (FD_ISSET(event_fd, &rfds) == 1 || FD_ISSET(event_fd, &wfds) == 0 ||
FD_ISSET(sock, &rfds) == 0 || FD_ISSET(sock, &wfds) == 0) {
close_files(2, sock, event_fd);
THROW_ERROR("bad select return");
}
close_files(2, sock, event_fd); close_files(2, sock, event_fd);
return 0; return 0;
} }