diff --git a/src/libos/src/net/io_multiplexing.rs b/src/libos/src/net/io_multiplexing.rs index 6530a1c1..48ed7fce 100644 --- a/src/libos/src/net/io_multiplexing.rs +++ b/src/libos/src/net/io_multiplexing.rs @@ -109,53 +109,73 @@ pub fn do_select( Ok(ret as usize) } -pub fn do_poll(polls: &mut [libc::pollfd], timeout: c_int) -> Result { +pub fn do_poll(pollfds: &mut [libc::pollfd], timeout: c_int) -> Result { info!( "poll: {:?}, timeout: {}", - polls.iter().map(|p| p.fd).collect::>(), + pollfds.iter().map(|p| p.fd).collect::>(), timeout ); let current_ref = process::get_current(); let mut proc = current_ref.lock().unwrap(); - // convert libos fd to Linux fd - for poll in polls.iter_mut() { - let file_ref = proc.get_files().lock().unwrap().get(poll.fd as FileDesc)?; + // Untrusted pollfd's that will be modified by OCall + let mut u_pollfds: Vec = pollfds.to_vec(); + + for (i, pollfd) in pollfds.iter_mut().enumerate() { + let file_ref = proc + .get_files() + .lock() + .unwrap() + .get(pollfd.fd as FileDesc)?; if let Ok(socket) = file_ref.as_socket() { - poll.fd = socket.fd(); + // convert libos fd to host fd in the copy to keep pollfds unchanged + u_pollfds[i].fd = socket.fd(); + u_pollfds[i].revents = 0; } else if let Ok(socket) = file_ref.as_unix_socket() { // FIXME: spin poll until can read (hack for php) - while (poll.events & libc::POLLIN) != 0 && socket.poll()?.0 == false { + while (pollfd.events & libc::POLLIN) != 0 && socket.poll()?.0 == false { spin_loop_hint(); } let (r, w, e) = socket.poll()?; if r { - poll.revents |= libc::POLLIN; + pollfd.revents |= libc::POLLIN; } if w { - poll.revents |= libc::POLLOUT; + pollfd.revents |= libc::POLLOUT; } + pollfd.revents &= pollfd.events; if e { - poll.revents |= libc::POLLERR; + pollfd.revents |= libc::POLLERR; } - poll.revents &= poll.events; warn!("poll unix socket is unimplemented, spin for read"); return Ok(1); } else if let Ok(dev_random) = file_ref.as_dev_random() { - return Ok(dev_random.poll(poll)?); + return Ok(dev_random.poll(pollfd)?); } else { - return_errno!(EBADF, "not a socket"); + return_errno!(EBADF, "not a supported file type"); } } - let ret = try_libc!(libc::ocall::poll( - polls.as_mut_ptr(), - polls.len() as u64, + + let num_events = try_libc!(libc::ocall::poll( + u_pollfds.as_mut_ptr(), + u_pollfds.len() as u64, timeout - )); - // recover fd ? - Ok(ret as usize) + )) as usize; + assert!(num_events <= pollfds.len()); + + // Copy back revents from the untrusted pollfds + let mut num_nonzero_revents = 0; + for (i, pollfd) in pollfds.iter_mut().enumerate() { + if u_pollfds[i].revents == 0 { + continue; + } + pollfd.revents = u_pollfds[i].revents; + num_nonzero_revents += 1; + } + assert!(num_nonzero_revents == num_events); + Ok(num_events as usize) } pub fn do_epoll_create1(flags: c_int) -> Result { diff --git a/test/server/main.c b/test/server/main.c index 4fd47800..8f5908ee 100644 --- a/test/server/main.c +++ b/test/server/main.c @@ -1,15 +1,16 @@ -#include -#include -#include #include #include +#include +#include +#include +#include +#include +#include +#include +#include #include #include #include -#include -#include -#include -#include #include "test.h" @@ -289,12 +290,34 @@ int test_fcntl_setfl_and_getfl() { return ret; } +int test_poll_sockets() { + int socks[2], ret; + socks[0] = socket(AF_INET, SOCK_STREAM, 0); + socks[1] = socket(AF_INET, SOCK_STREAM, 0); + struct pollfd pollfds[] = { + { .fd = socks[0], .events = POLLIN }, + { .fd = socks[1], .events = POLLIN }, + }; + + ret = poll(pollfds, 2, 0); + if (ret < 0) + THROW_ERROR("poll error"); + + if (pollfds[0].fd != socks[0] || + pollfds[0].events != POLLIN || + pollfds[1].fd != socks[1] || + pollfds[1].events != POLLIN) + THROW_ERROR("fd and events of pollfd should remain unchanged"); + return 0; +} + static test_case_t test_cases[] = { TEST_CASE(test_read_write), TEST_CASE(test_send_recv), TEST_CASE(test_sendmsg_recvmsg), TEST_CASE(test_sendmsg_recvmsg_connectionless), TEST_CASE(test_fcntl_setfl_and_getfl), + TEST_CASE(test_poll_sockets), }; int main(int argc, const char* argv[]) {