Fix a bug in poll that modifies input fds

This commit is contained in:
He Sun 2020-03-06 13:43:16 +08:00 committed by tate.thl
parent 74fad28938
commit 06f7763d55
2 changed files with 69 additions and 26 deletions

@ -109,53 +109,73 @@ pub fn do_select(
Ok(ret as usize)
}
pub fn do_poll(polls: &mut [libc::pollfd], timeout: c_int) -> Result<usize> {
pub fn do_poll(pollfds: &mut [libc::pollfd], timeout: c_int) -> Result<usize> {
info!(
"poll: {:?}, timeout: {}",
polls.iter().map(|p| p.fd).collect::<Vec<_>>(),
pollfds.iter().map(|p| p.fd).collect::<Vec<_>>(),
timeout
);
let current_ref = process::get_current();
let mut proc = current_ref.lock().unwrap();
// convert libos fd to Linux fd
for poll in polls.iter_mut() {
let file_ref = proc.get_files().lock().unwrap().get(poll.fd as FileDesc)?;
// Untrusted pollfd's that will be modified by OCall
let mut u_pollfds: Vec<libc::pollfd> = pollfds.to_vec();
for (i, pollfd) in pollfds.iter_mut().enumerate() {
let file_ref = proc
.get_files()
.lock()
.unwrap()
.get(pollfd.fd as FileDesc)?;
if let Ok(socket) = file_ref.as_socket() {
poll.fd = socket.fd();
// convert libos fd to host fd in the copy to keep pollfds unchanged
u_pollfds[i].fd = socket.fd();
u_pollfds[i].revents = 0;
} else if let Ok(socket) = file_ref.as_unix_socket() {
// FIXME: spin poll until can read (hack for php)
while (poll.events & libc::POLLIN) != 0 && socket.poll()?.0 == false {
while (pollfd.events & libc::POLLIN) != 0 && socket.poll()?.0 == false {
spin_loop_hint();
}
let (r, w, e) = socket.poll()?;
if r {
poll.revents |= libc::POLLIN;
pollfd.revents |= libc::POLLIN;
}
if w {
poll.revents |= libc::POLLOUT;
pollfd.revents |= libc::POLLOUT;
}
pollfd.revents &= pollfd.events;
if e {
poll.revents |= libc::POLLERR;
pollfd.revents |= libc::POLLERR;
}
poll.revents &= poll.events;
warn!("poll unix socket is unimplemented, spin for read");
return Ok(1);
} else if let Ok(dev_random) = file_ref.as_dev_random() {
return Ok(dev_random.poll(poll)?);
return Ok(dev_random.poll(pollfd)?);
} else {
return_errno!(EBADF, "not a socket");
return_errno!(EBADF, "not a supported file type");
}
}
let ret = try_libc!(libc::ocall::poll(
polls.as_mut_ptr(),
polls.len() as u64,
let num_events = try_libc!(libc::ocall::poll(
u_pollfds.as_mut_ptr(),
u_pollfds.len() as u64,
timeout
));
// recover fd ?
Ok(ret as usize)
)) as usize;
assert!(num_events <= pollfds.len());
// Copy back revents from the untrusted pollfds
let mut num_nonzero_revents = 0;
for (i, pollfd) in pollfds.iter_mut().enumerate() {
if u_pollfds[i].revents == 0 {
continue;
}
pollfd.revents = u_pollfds[i].revents;
num_nonzero_revents += 1;
}
assert!(num_nonzero_revents == num_events);
Ok(num_events as usize)
}
pub fn do_epoll_create1(flags: c_int) -> Result<FileDesc> {

@ -1,15 +1,16 @@
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <errno.h>
#include <fcntl.h>
#include <poll.h>
#include <spawn.h>
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <unistd.h>
#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/wait.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <spawn.h>
#include <unistd.h>
#include "test.h"
@ -289,12 +290,34 @@ int test_fcntl_setfl_and_getfl() {
return ret;
}
int test_poll_sockets() {
int socks[2], ret;
socks[0] = socket(AF_INET, SOCK_STREAM, 0);
socks[1] = socket(AF_INET, SOCK_STREAM, 0);
struct pollfd pollfds[] = {
{ .fd = socks[0], .events = POLLIN },
{ .fd = socks[1], .events = POLLIN },
};
ret = poll(pollfds, 2, 0);
if (ret < 0)
THROW_ERROR("poll error");
if (pollfds[0].fd != socks[0] ||
pollfds[0].events != POLLIN ||
pollfds[1].fd != socks[1] ||
pollfds[1].events != POLLIN)
THROW_ERROR("fd and events of pollfd should remain unchanged");
return 0;
}
static test_case_t test_cases[] = {
TEST_CASE(test_read_write),
TEST_CASE(test_send_recv),
TEST_CASE(test_sendmsg_recvmsg),
TEST_CASE(test_sendmsg_recvmsg_connectionless),
TEST_CASE(test_fcntl_setfl_and_getfl),
TEST_CASE(test_poll_sockets),
};
int main(int argc, const char* argv[]) {