diff --git a/src/libos/src/net/io_multiplexing/mod.rs b/src/libos/src/net/io_multiplexing/mod.rs index 0957f236..b04b1d1a 100644 --- a/src/libos/src/net/io_multiplexing/mod.rs +++ b/src/libos/src/net/io_multiplexing/mod.rs @@ -13,7 +13,7 @@ pub use self::io_event::{ }; pub use self::poll::{do_poll, PollEvent, PollEventFlags}; pub use self::poll_new::{do_poll_new, PollFd}; -pub use self::select::{select, FdSetExt}; +pub use self::select::{do_select, FdSetExt}; use fs::{AsDevRandom, AsEvent, CreationFlags, File, FileDesc, FileRef, HostFd, PipeType}; use std::any::Any; diff --git a/src/libos/src/net/io_multiplexing/select.rs b/src/libos/src/net/io_multiplexing/select.rs index 66387efd..8c183875 100644 --- a/src/libos/src/net/io_multiplexing/select.rs +++ b/src/libos/src/net/io_multiplexing/select.rs @@ -1,93 +1,123 @@ -use super::super::time::timer_slack::TIMERSLACK; -use super::*; +use std::time::Duration; -pub fn select( - nfds: c_int, - readfds: &mut libc::fd_set, - writefds: &mut libc::fd_set, - exceptfds: &mut libc::fd_set, - timeout: *mut timeval_t, +use super::poll_new::{do_poll_new, PollFd}; +use crate::fs::IoEvents; +use crate::prelude::*; + +pub fn do_select( + num_fds: FileDesc, + mut readfds: Option<&mut libc::fd_set>, + mut writefds: Option<&mut libc::fd_set>, + mut exceptfds: Option<&mut libc::fd_set>, + timeout: Option<&mut Duration>, ) -> Result { debug!( - "read: {} write: {} exception: {}", + "do_select: read: {}, write: {}, exception: {}, timeout: {:?}", readfds.format(), writefds.format(), - exceptfds.format() + exceptfds.format(), + timeout, ); - let mut ready_num = 0; - let mut pollfds: Vec = Vec::new(); - - for fd in 0..(nfds as FileDesc) { - let (r, w, e) = ( - readfds.is_set(fd), - writefds.is_set(fd), - exceptfds.is_set(fd), - ); - if !(r || w || e) { - continue; - } - - if current!().file(fd).is_err() { - return_errno!( - EBADF, - "An invalid file descriptor was given in one of the sets" - ); - } - - let mut events = PollEventFlags::empty(); - if r { - events |= PollEventFlags::POLLIN; - } - if w { - events |= PollEventFlags::POLLOUT; - } - if e { - events |= PollEventFlags::POLLPRI; - } - - pollfds.push(PollEvent::new(fd, events)); + if num_fds as usize > libc::FD_SETSIZE { + return_errno!(EINVAL, "the value is too large"); } - let mut origin_timeout: timeval_t = if timeout.is_null() { - Default::default() - } else { - unsafe { *timeout } + // Convert the three fd_set's to an array of PollFd + let poll_fds = { + let mut poll_fds = Vec::new(); + for fd in (0..num_fds).into_iter() { + let events = { + let (mut readable, mut writable, mut except) = (false, false, false); + if let Some(readfds) = readfds.as_ref() { + if readfds.is_set(fd) { + readable = true; + } + } + if let Some(writefds) = writefds.as_ref() { + if writefds.is_set(fd) { + writable = true; + } + } + if let Some(exceptfds) = exceptfds.as_ref() { + if exceptfds.is_set(fd) { + except = true; + } + } + convert_rwe_to_events(readable, writable, except) + }; + + if events.is_empty() { + continue; + } + + let poll_fd = PollFd::new(fd, events); + poll_fds.push(poll_fd); + } + poll_fds }; - - let ret = do_poll(&mut pollfds, timeout)?; - - readfds.clear(); - writefds.clear(); - exceptfds.clear(); - - if !timeout.is_null() { - let time_left = unsafe { *(timeout) }; - time_left.validate()?; - assert!( - // Note: TIMERSLACK is a single value use maintained by the libOS and will not vary for different threads. - time_left.as_duration() <= origin_timeout.as_duration() + (*TIMERSLACK).to_duration() - ); + // Clear up the three input fd_set's, which will be used for output as well + if let Some(readfds) = readfds.as_mut() { + readfds.clear(); + } + if let Some(writefds) = writefds.as_mut() { + writefds.clear(); + } + if let Some(exceptfds) = exceptfds.as_mut() { + exceptfds.clear(); } - debug!("returned pollfds are {:?}", pollfds); - for pollfd in &pollfds { - let (r_poll, w_poll, e_poll) = convert_to_readable_writable_exceptional(pollfd.revents()); - if r_poll { - readfds.set(pollfd.fd())?; - ready_num += 1; - } - if w_poll { - writefds.set(pollfd.fd())?; - ready_num += 1; - } - if e_poll { - exceptfds.set(pollfd.fd())?; - ready_num += 1; - } + // Do the poll syscall that is equivalent to the select syscall + let num_ready_fds = do_poll_new(&poll_fds, timeout)?; + if num_ready_fds == 0 { + return Ok(0); } - Ok(ready_num) + // Convert poll's pollfd results to select's fd_set results + let mut num_events = 0; + for poll_fd in &poll_fds { + let fd = poll_fd.fd(); + let revents = poll_fd.revents().get(); + let (readable, writable, exception) = convert_events_to_rwe(&revents); + if readable { + readfds.set(fd); + num_events += 1; + } + if writable { + writefds.set(fd); + num_events += 1; + } + if exception { + exceptfds.set(fd); + num_events += 1; + } + } + Ok(num_events) +} + +// Convert select's rwe input to poll's IoEvents input accordingg to Linux's +// behavior. +fn convert_rwe_to_events(readable: bool, writable: bool, except: bool) -> IoEvents { + let mut events = IoEvents::empty(); + if readable { + events |= IoEvents::IN; + } + if writable { + events |= IoEvents::OUT; + } + if except { + events |= IoEvents::PRI; + } + events +} + +// Convert poll's IoEvents results to select's rwe results according to Linux's +// behavior. +fn convert_events_to_rwe(events: &IoEvents) -> (bool, bool, bool) { + let readable = events.intersects(IoEvents::IN | IoEvents::HUP | IoEvents::ERR); + let writable = events.intersects(IoEvents::OUT | IoEvents::ERR); + let exception = events.contains(IoEvents::PRI); + (readable, writable, exception) } /// Safe methods for `libc::fd_set` @@ -163,20 +193,25 @@ impl FdSetExt for libc::fd_set { } } -// The correspondence is from man2/select.2.html -fn convert_to_readable_writable_exceptional(events: PollEventFlags) -> (bool, bool, bool) { - ( - (PollEventFlags::POLLRDNORM - | PollEventFlags::POLLRDBAND - | PollEventFlags::POLLIN - | PollEventFlags::POLLHUP - | PollEventFlags::POLLERR) - .intersects(events), - (PollEventFlags::POLLWRBAND - | PollEventFlags::POLLWRNORM - | PollEventFlags::POLLOUT - | PollEventFlags::POLLERR) - .intersects(events), - PollEventFlags::POLLPRI.intersects(events), - ) +trait FdSetOptionExt { + fn format(&self) -> String; + fn set(&mut self, fd: FileDesc) -> Result<()>; +} + +impl FdSetOptionExt for Option<&mut libc::fd_set> { + fn format(&self) -> String { + if let Some(self_) = self.as_ref() { + self_.format() + } else { + "(empty)".to_string() + } + } + + fn set(&mut self, fd: FileDesc) -> Result<()> { + if let Some(inner) = self.as_mut() { + inner.set(fd) + } else { + Ok(()) + } + } } diff --git a/src/libos/src/net/syscalls.rs b/src/libos/src/net/syscalls.rs index 2044c1f0..76936660 100644 --- a/src/libos/src/net/syscalls.rs +++ b/src/libos/src/net/syscalls.rs @@ -530,54 +530,58 @@ pub fn do_select( exceptfds: *mut libc::fd_set, timeout: *mut timeval_t, ) -> Result { - // check arguments - let soft_rlimit_nofile = current!() - .rlimits() - .lock() - .unwrap() - .get(resource_t::RLIMIT_NOFILE) - .get_cur(); - if nfds < 0 || nfds > libc::FD_SETSIZE as i32 || nfds as u64 > soft_rlimit_nofile { - return_errno!( - EINVAL, - "nfds is negative or exceeds the resource limit or FD_SETSIZE" - ); - } - - if !timeout.is_null() { - from_user::check_ptr(timeout)?; - unsafe { - (*timeout).validate()?; + let nfds = { + let soft_rlimit_nofile = current!() + .rlimits() + .lock() + .unwrap() + .get(resource_t::RLIMIT_NOFILE) + .get_cur(); + if nfds < 0 || nfds > libc::FD_SETSIZE as i32 || nfds as u64 > soft_rlimit_nofile { + return_errno!( + EINVAL, + "nfds is negative or exceeds the resource limit or FD_SETSIZE" + ); } - } + nfds as FileDesc + }; - // Select handles empty set and null in the same way - // TODO: Elegently handle the empty fd_set without allocating redundant fd_set - let mut empty_set_for_read = libc::fd_set::new_empty(); - let mut empty_set_for_write = libc::fd_set::new_empty(); - let mut empty_set_for_except = libc::fd_set::new_empty(); + let mut timeout_c = if !timeout.is_null() { + from_user::check_ptr(timeout)?; + let timeval = unsafe { &mut *timeout }; + timeval.validate()?; + Some(timeval) + } else { + None + }; + let mut timeout = timeout_c.as_ref().map(|timeout_c| timeout_c.as_duration()); let readfds = if !readfds.is_null() { from_user::check_mut_ptr(readfds)?; - unsafe { &mut *readfds } + Some(unsafe { &mut *readfds }) } else { - &mut empty_set_for_read + None }; let writefds = if !writefds.is_null() { from_user::check_mut_ptr(writefds)?; - unsafe { &mut *writefds } + Some(unsafe { &mut *writefds }) } else { - &mut empty_set_for_write + None }; let exceptfds = if !exceptfds.is_null() { from_user::check_mut_ptr(exceptfds)?; - unsafe { &mut *exceptfds } + Some(unsafe { &mut *exceptfds }) } else { - &mut empty_set_for_except + None }; - let ret = io_multiplexing::select(nfds, readfds, writefds, exceptfds, timeout)?; - Ok(ret) + let ret = io_multiplexing::do_select(nfds, readfds, writefds, exceptfds, timeout.as_mut()); + + if let Some(timeout_c) = timeout_c { + *timeout_c = timeout.unwrap().into(); + } + + ret } pub fn do_poll(fds: *mut libc::pollfd, nfds: libc::nfds_t, timeout_ms: c_int) -> Result { diff --git a/src/libos/src/time/mod.rs b/src/libos/src/time/mod.rs index 631ddaeb..c7141fef 100644 --- a/src/libos/src/time/mod.rs +++ b/src/libos/src/time/mod.rs @@ -53,6 +53,15 @@ impl timeval_t { } } +impl From for timeval_t { + fn from(duration: Duration) -> timeval_t { + let sec = duration.as_secs() as time_t; + let usec = duration.subsec_micros() as i64; + debug_assert!(sec >= 0); // nsec >= 0 always holds + timeval_t { sec, usec } + } +} + pub fn do_gettimeofday() -> timeval_t { extern "C" { fn occlum_ocall_gettimeofday(tv: *mut timeval_t) -> sgx_status_t; diff --git a/test/pipe/main.c b/test/pipe/main.c index 21d8acba..51fa0d4a 100644 --- a/test/pipe/main.c +++ b/test/pipe/main.c @@ -287,7 +287,7 @@ int test_select_read_write() { FD_ZERO(&rfds); FD_SET(pipe_fds[0], &rfds); - if (select(pipe_fds[0] + 1, &rfds, NULL, NULL, NULL) != 1) { + if (select(pipe_fds[0] + 1, &rfds, NULL, NULL, NULL) <= 0) { free_pipe(pipe_fds); THROW_ERROR("select failed"); } @@ -316,13 +316,13 @@ static test_case_t test_cases[] = { TEST_CASE(test_fcntl_get_flags), TEST_CASE(test_fcntl_set_flags), TEST_CASE(test_create_with_flags), - //TEST_CASE(test_select_timeout), + TEST_CASE(test_select_timeout), TEST_CASE(test_poll_timeout), TEST_CASE(test_epoll_timeout), - //TEST_CASE(test_select_no_timeout), + TEST_CASE(test_select_no_timeout), TEST_CASE(test_poll_no_timeout), TEST_CASE(test_epoll_no_timeout), - //TEST_CASE(test_select_read_write), + TEST_CASE(test_select_read_write), }; int main(int argc, const char *argv[]) {