Fix a bug in poll that modifies input fds

This commit is contained in:
He Sun 2020-03-06 13:43:16 +08:00 committed by tate.thl
parent 74fad28938
commit 06f7763d55
2 changed files with 69 additions and 26 deletions

@ -109,53 +109,73 @@ pub fn do_select(
Ok(ret as usize) Ok(ret as usize)
} }
pub fn do_poll(polls: &mut [libc::pollfd], timeout: c_int) -> Result<usize> { pub fn do_poll(pollfds: &mut [libc::pollfd], timeout: c_int) -> Result<usize> {
info!( info!(
"poll: {:?}, timeout: {}", "poll: {:?}, timeout: {}",
polls.iter().map(|p| p.fd).collect::<Vec<_>>(), pollfds.iter().map(|p| p.fd).collect::<Vec<_>>(),
timeout timeout
); );
let current_ref = process::get_current(); let current_ref = process::get_current();
let mut proc = current_ref.lock().unwrap(); let mut proc = current_ref.lock().unwrap();
// convert libos fd to Linux fd // Untrusted pollfd's that will be modified by OCall
for poll in polls.iter_mut() { let mut u_pollfds: Vec<libc::pollfd> = pollfds.to_vec();
let file_ref = proc.get_files().lock().unwrap().get(poll.fd as FileDesc)?;
for (i, pollfd) in pollfds.iter_mut().enumerate() {
let file_ref = proc
.get_files()
.lock()
.unwrap()
.get(pollfd.fd as FileDesc)?;
if let Ok(socket) = file_ref.as_socket() { if let Ok(socket) = file_ref.as_socket() {
poll.fd = socket.fd(); // convert libos fd to host fd in the copy to keep pollfds unchanged
u_pollfds[i].fd = socket.fd();
u_pollfds[i].revents = 0;
} else if let Ok(socket) = file_ref.as_unix_socket() { } else if let Ok(socket) = file_ref.as_unix_socket() {
// FIXME: spin poll until can read (hack for php) // FIXME: spin poll until can read (hack for php)
while (poll.events & libc::POLLIN) != 0 && socket.poll()?.0 == false { while (pollfd.events & libc::POLLIN) != 0 && socket.poll()?.0 == false {
spin_loop_hint(); spin_loop_hint();
} }
let (r, w, e) = socket.poll()?; let (r, w, e) = socket.poll()?;
if r { if r {
poll.revents |= libc::POLLIN; pollfd.revents |= libc::POLLIN;
} }
if w { if w {
poll.revents |= libc::POLLOUT; pollfd.revents |= libc::POLLOUT;
} }
pollfd.revents &= pollfd.events;
if e { if e {
poll.revents |= libc::POLLERR; pollfd.revents |= libc::POLLERR;
} }
poll.revents &= poll.events;
warn!("poll unix socket is unimplemented, spin for read"); warn!("poll unix socket is unimplemented, spin for read");
return Ok(1); return Ok(1);
} else if let Ok(dev_random) = file_ref.as_dev_random() { } else if let Ok(dev_random) = file_ref.as_dev_random() {
return Ok(dev_random.poll(poll)?); return Ok(dev_random.poll(pollfd)?);
} else { } else {
return_errno!(EBADF, "not a socket"); return_errno!(EBADF, "not a supported file type");
} }
} }
let ret = try_libc!(libc::ocall::poll(
polls.as_mut_ptr(), let num_events = try_libc!(libc::ocall::poll(
polls.len() as u64, u_pollfds.as_mut_ptr(),
u_pollfds.len() as u64,
timeout timeout
)); )) as usize;
// recover fd ? assert!(num_events <= pollfds.len());
Ok(ret as usize)
// Copy back revents from the untrusted pollfds
let mut num_nonzero_revents = 0;
for (i, pollfd) in pollfds.iter_mut().enumerate() {
if u_pollfds[i].revents == 0 {
continue;
}
pollfd.revents = u_pollfds[i].revents;
num_nonzero_revents += 1;
}
assert!(num_nonzero_revents == num_events);
Ok(num_events as usize)
} }
pub fn do_epoll_create1(flags: c_int) -> Result<FileDesc> { pub fn do_epoll_create1(flags: c_int) -> Result<FileDesc> {

@ -1,15 +1,16 @@
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <errno.h> #include <errno.h>
#include <fcntl.h> #include <fcntl.h>
#include <poll.h>
#include <spawn.h>
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <unistd.h>
#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/types.h> #include <sys/types.h>
#include <sys/socket.h> #include <sys/socket.h>
#include <sys/wait.h> #include <sys/wait.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <spawn.h>
#include <unistd.h>
#include "test.h" #include "test.h"
@ -289,12 +290,34 @@ int test_fcntl_setfl_and_getfl() {
return ret; return ret;
} }
int test_poll_sockets() {
int socks[2], ret;
socks[0] = socket(AF_INET, SOCK_STREAM, 0);
socks[1] = socket(AF_INET, SOCK_STREAM, 0);
struct pollfd pollfds[] = {
{ .fd = socks[0], .events = POLLIN },
{ .fd = socks[1], .events = POLLIN },
};
ret = poll(pollfds, 2, 0);
if (ret < 0)
THROW_ERROR("poll error");
if (pollfds[0].fd != socks[0] ||
pollfds[0].events != POLLIN ||
pollfds[1].fd != socks[1] ||
pollfds[1].events != POLLIN)
THROW_ERROR("fd and events of pollfd should remain unchanged");
return 0;
}
static test_case_t test_cases[] = { static test_case_t test_cases[] = {
TEST_CASE(test_read_write), TEST_CASE(test_read_write),
TEST_CASE(test_send_recv), TEST_CASE(test_send_recv),
TEST_CASE(test_sendmsg_recvmsg), TEST_CASE(test_sendmsg_recvmsg),
TEST_CASE(test_sendmsg_recvmsg_connectionless), TEST_CASE(test_sendmsg_recvmsg_connectionless),
TEST_CASE(test_fcntl_setfl_and_getfl), TEST_CASE(test_fcntl_setfl_and_getfl),
TEST_CASE(test_poll_sockets),
}; };
int main(int argc, const char* argv[]) { int main(int argc, const char* argv[]) {