From c85163ec0a6ac6f56ae80539cdf336fd3221b19d Mon Sep 17 00:00:00 2001 From: He Sun Date: Mon, 6 Jul 2020 12:56:29 +0800 Subject: [PATCH] Add notification mechanism for basic IO events 1. Add notification mechanism for select, poll, epoll and blocking IO 2. Add pipe support for select, poll and blocking IO --- .gitmodules | 3 + Makefile | 1 + deps/ringbuf | 1 + deps/ringbuf.patch | 40 ++ src/Enclave.edl | 13 +- src/libos/Cargo.lock | 8 + src/libos/Cargo.toml | 1 + src/libos/src/fs/dev_fs/dev_random.rs | 21 +- src/libos/src/fs/file.rs | 12 + src/libos/src/fs/mod.rs | 9 +- src/libos/src/fs/pipe.rs | 197 +++--- src/libos/src/lib.rs | 2 + src/libos/src/net/io_multiplexing/io_event.rs | 54 ++ src/libos/src/net/io_multiplexing/mod.rs | 8 +- src/libos/src/net/io_multiplexing/poll.rs | 255 ++++++-- src/libos/src/net/io_multiplexing/select.rs | 230 +++---- src/libos/src/net/mod.rs | 5 +- src/libos/src/net/syscalls.rs | 34 +- src/libos/src/net/unix_socket.rs | 105 ++-- src/libos/src/process/thread/mod.rs | 20 + src/libos/src/syscall/mod.rs | 8 +- src/libos/src/time/mod.rs | 7 + src/libos/src/util/ring_buf.rs | 568 ++++++++++++------ src/pal/src/ocalls/net.c | 40 +- test/pipe/main.c | 123 +++- test/server/main.c | 43 +- test/unix_socket/main.c | 23 + 27 files changed, 1253 insertions(+), 578 deletions(-) create mode 160000 deps/ringbuf create mode 100644 deps/ringbuf.patch create mode 100644 src/libos/src/net/io_multiplexing/io_event.rs diff --git a/.gitmodules b/.gitmodules index 6a40b329..6cc60b7a 100644 --- a/.gitmodules +++ b/.gitmodules @@ -22,3 +22,6 @@ path = deps/grpc-rust url = https://github.com/stepancheg/grpc-rust.git branch = v0.7 +[submodule "deps/ringbuf"] + path = deps/ringbuf + url = https://github.com/agerasev/ringbuf.git diff --git a/Makefile b/Makefile index 4cd383d8..ad4c9344 100644 --- a/Makefile +++ b/Makefile @@ -30,6 +30,7 @@ submodule: githooks @# Try to apply the patches. If failed, check if the patches are already applied cd deps/rust-sgx-sdk && git apply ../rust-sgx-sdk.patch >/dev/null 2>&1 || git apply ../rust-sgx-sdk.patch -R --check cd deps/serde-json-sgx && git apply ../serde-json-sgx.patch >/dev/null 2>&1 || git apply ../serde-json-sgx.patch -R --check + cd deps/ringbuf && git apply ../ringbuf.patch >/dev/null 2>&1 || git apply ../ringbuf.patch -R --check @# Enclaves used by tools are running in simulation mode by default to run faster. @rm -rf build build_sim diff --git a/deps/ringbuf b/deps/ringbuf new file mode 160000 index 00000000..b8f40358 --- /dev/null +++ b/deps/ringbuf @@ -0,0 +1 @@ +Subproject commit b8f403584c2adbf1f1d78594ece36f2bf144c095 diff --git a/deps/ringbuf.patch b/deps/ringbuf.patch new file mode 100644 index 00000000..9d03bbcc --- /dev/null +++ b/deps/ringbuf.patch @@ -0,0 +1,40 @@ +diff --git a/Cargo.toml b/Cargo.toml +index 92b7e5a..d41b5af 100644 +--- a/Cargo.toml ++++ b/Cargo.toml +@@ -17,4 +17,8 @@ license = "MIT/Apache-2.0" + default = [] + benchmark = [] + ++[patch.'https://github.com/apache/teaclave-sgx-sdk.git'] ++sgx_tstd = { path = "../rust-sgx-sdk/sgx_tstd" } ++ + [dependencies] ++sgx_tstd = { path = "../rust-sgx-sdk/sgx_tstd", features = ["backtrace"] } +diff --git a/src/lib.rs b/src/lib.rs +index 5b45f90..6ec90f1 100644 +--- a/src/lib.rs ++++ b/src/lib.rs +@@ -116,6 +116,10 @@ + + #![cfg_attr(feature = "benchmark", feature(test))] + ++#![no_std] ++#[macro_use] ++extern crate sgx_tstd as std; ++ + #[cfg(feature = "benchmark")] + extern crate test; + +diff --git a/src/ring_buffer.rs b/src/ring_buffer.rs +index 8ae68af..aa4fb28 100644 +--- a/src/ring_buffer.rs ++++ b/src/ring_buffer.rs +@@ -7,6 +7,7 @@ use std::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, ++ vec::Vec, + }; + + use crate::{consumer::Consumer, producer::Producer}; diff --git a/src/Enclave.edl b/src/Enclave.edl index c1ada329..4886176b 100644 --- a/src/Enclave.edl +++ b/src/Enclave.edl @@ -185,13 +185,12 @@ enclave { int flags ) propagate_errno; - int occlum_ocall_select( - int nfds, - [in, out] fd_set *readfds, - [in, out] fd_set *writefds, - [in, out] fd_set *exceptfds, - [in, out] struct timeval *timeout - ) propagate_errno; + int occlum_ocall_poll( + [in, out, count=nfds] struct pollfd *fds, + nfds_t nfds, + [in, out] struct timeval *timeout, + int efd + )propagate_errno; void occlum_ocall_print_log(uint32_t level, [in, string] const char* msg); void occlum_ocall_flush_log(void); diff --git a/src/libos/Cargo.lock b/src/libos/Cargo.lock index be5764fe..fd41f318 100644 --- a/src/libos/Cargo.lock +++ b/src/libos/Cargo.lock @@ -13,6 +13,7 @@ dependencies = [ "rcore-fs-mountfs", "rcore-fs-ramfs", "rcore-fs-sefs", + "ringbuf", "serde", "serde_json", "sgx_tcrypto", @@ -389,6 +390,13 @@ dependencies = [ "rand_core 0.3.1", ] +[[package]] +name = "ringbuf" +version = "0.2.1" +dependencies = [ + "sgx_tstd", +] + [[package]] name = "ryu" version = "1.0.4" diff --git a/src/libos/Cargo.toml b/src/libos/Cargo.toml index 7d50d6e3..68595353 100644 --- a/src/libos/Cargo.toml +++ b/src/libos/Cargo.toml @@ -13,6 +13,7 @@ bitvec = { version = "0.17", default-features = false, features = ["alloc"] } log = "0.4" lazy_static = { version = "1.1.0", features = ["spin_no_std"] } # Implies nightly derive_builder = "0.7.2" +ringbuf = { path = "../../deps/ringbuf" } rcore-fs = { path = "../../deps/sefs/rcore-fs" } rcore-fs-sefs = { path = "../../deps/sefs/rcore-fs-sefs" } rcore-fs-ramfs = { path = "../../deps/sefs/rcore-fs-ramfs" } diff --git a/src/libos/src/fs/dev_fs/dev_random.rs b/src/libos/src/fs/dev_fs/dev_random.rs index 96a43a39..85fd5a6a 100644 --- a/src/libos/src/fs/dev_fs/dev_random.rs +++ b/src/libos/src/fs/dev_fs/dev_random.rs @@ -1,4 +1,5 @@ use super::*; +use crate::net::PollEventFlags; #[derive(Debug)] pub struct DevRandom; @@ -59,27 +60,15 @@ impl File for DevRandom { }) } + fn poll(&self) -> Result<(PollEventFlags)> { + Ok(PollEventFlags::POLLIN) + } + fn as_any(&self) -> &dyn Any { self } } -impl DevRandom { - pub fn poll(&self, fd: &mut libc::pollfd) -> Result { - // Just support POLLIN event, because the device is read-only currently - let (num, revents_option) = if (fd.events & libc::POLLIN) != 0 { - (1, Some(libc::POLLIN)) - } else { - // Device is not ready - (0, None) - }; - if let Some(revents) = revents_option { - fd.revents = revents; - } - Ok(num) - } -} - pub trait AsDevRandom { fn as_dev_random(&self) -> Result<&DevRandom>; } diff --git a/src/libos/src/fs/file.rs b/src/libos/src/fs/file.rs index 2d789bf8..f9dbab17 100644 --- a/src/libos/src/fs/file.rs +++ b/src/libos/src/fs/file.rs @@ -91,6 +91,18 @@ pub trait File: Debug + Sync + Send + Any { return_op_unsupported_error!("set_advisory_lock") } + fn poll(&self) -> Result<(crate::net::PollEventFlags)> { + return_op_unsupported_error!("poll") + } + + fn enqueue_event(&self, _: crate::net::IoEvent) -> Result<()> { + return_op_unsupported_error!("enqueue_event"); + } + + fn dequeue_event(&self) -> Result<()> { + return_op_unsupported_error!("dequeue_event"); + } + fn as_any(&self) -> &dyn Any; } diff --git a/src/libos/src/fs/mod.rs b/src/libos/src/fs/mod.rs index 6be7bb09..16b79b7d 100644 --- a/src/libos/src/fs/mod.rs +++ b/src/libos/src/fs/mod.rs @@ -11,17 +11,16 @@ use std::path::Path; use untrusted::{SliceAsMutPtrAndLen, SliceAsPtrAndLen}; pub use self::dev_fs::AsDevRandom; -pub use self::event_file::{AsEvent, EventFile}; +pub use self::event_file::{AsEvent, EventCreationFlags, EventFile}; pub use self::file::{File, FileRef}; pub use self::file_ops::{ - occlum_ocall_ioctl, AccessMode, CreationFlags, FileMode, Stat, StatusFlags, + occlum_ocall_ioctl, AccessMode, CreationFlags, FileMode, Flock, FlockType, IoctlCmd, Stat, + StatusFlags, StructuredIoctlArgType, StructuredIoctlNum, }; -pub use self::file_ops::{Flock, FlockType}; -pub use self::file_ops::{IoctlCmd, StructuredIoctlArgType, StructuredIoctlNum}; pub use self::file_table::{FileDesc, FileTable}; pub use self::fs_view::FsView; pub use self::inode_file::{AsINodeFile, INodeExt, INodeFile}; -pub use self::pipe::Pipe; +pub use self::pipe::PipeType; pub use self::rootfs::ROOT_INODE; pub use self::stdio::{HostStdioFds, StdinFile, StdoutFile}; pub use self::syscalls::*; diff --git a/src/libos/src/fs/pipe.rs b/src/libos/src/fs/pipe.rs index 7cf1f212..918b2903 100644 --- a/src/libos/src/fs/pipe.rs +++ b/src/libos/src/fs/pipe.rs @@ -1,37 +1,34 @@ use super::*; +use net::{IoEvent, PollEventFlags}; use util::ring_buf::*; -// TODO: Use Waiter and WaitQueue infrastructure to sleep when blocking // TODO: Add F_SETPIPE_SZ in fcntl to dynamically change the size of pipe // to improve memory efficiency. This value is got from /proc/sys/fs/pipe-max-size on linux. pub const PIPE_BUF_SIZE: usize = 1024 * 1024; -#[derive(Debug)] -pub struct Pipe { - pub reader: PipeReader, - pub writer: PipeWriter, -} +pub fn pipe(flags: StatusFlags) -> Result<(PipeReader, PipeWriter)> { + let (buffer_reader, buffer_writer) = + ring_buffer(PIPE_BUF_SIZE).map_err(|e| errno!(ENFILE, "No memory for new pipes"))?; + // Only O_NONBLOCK and O_DIRECT can be applied during pipe creation + let valid_flags = flags & (StatusFlags::O_NONBLOCK | StatusFlags::O_DIRECT); -impl Pipe { - pub fn new(flags: StatusFlags) -> Result { - let mut ring_buf = - RingBuf::new(PIPE_BUF_SIZE).map_err(|e| errno!(ENFILE, "No memory for new pipes"))?; - // Only O_NONBLOCK and O_DIRECT can be applied during pipe creation - let valid_flags = flags & (StatusFlags::O_NONBLOCK | StatusFlags::O_DIRECT); - Ok(Pipe { - reader: PipeReader { - inner: SgxMutex::new(ring_buf.reader), - status_flags: RwLock::new(valid_flags), - }, - writer: PipeWriter { - inner: SgxMutex::new(ring_buf.writer), - status_flags: RwLock::new(valid_flags), - }, - }) + if flags.contains(StatusFlags::O_NONBLOCK) { + buffer_reader.set_non_blocking(); + buffer_writer.set_non_blocking(); } + + Ok(( + PipeReader { + inner: SgxMutex::new(buffer_reader), + status_flags: RwLock::new(valid_flags), + }, + PipeWriter { + inner: SgxMutex::new(buffer_writer), + status_flags: RwLock::new(valid_flags), + }, + )) } -#[derive(Debug)] pub struct PipeReader { inner: SgxMutex, status_flags: RwLock, @@ -39,32 +36,13 @@ pub struct PipeReader { impl File for PipeReader { fn read(&self, buf: &mut [u8]) -> Result { - let ringbuf = self.inner.lock().unwrap(); - ringbuf.read(buf) + let mut ringbuf = self.inner.lock().unwrap(); + ringbuf.read_from_buffer(buf) } fn readv(&self, bufs: &mut [&mut [u8]]) -> Result { let mut ringbuf = self.inner.lock().unwrap(); - let mut total_bytes = 0; - for buf in bufs { - match ringbuf.read(buf) { - Ok(this_len) => { - total_bytes += this_len; - if this_len < buf.len() { - break; - } - } - Err(e) => { - match total_bytes { - // a complete failure - 0 => return Err(e), - // a partially failure - _ => break, - } - } - } - } - Ok(total_bytes) + ringbuf.read_from_vector(bufs) } fn get_access_mode(&self) -> Result { @@ -81,9 +59,39 @@ impl File for PipeReader { // Only O_NONBLOCK, O_ASYNC and O_DIRECT can be set *status_flags = new_status_flags & (StatusFlags::O_NONBLOCK | StatusFlags::O_ASYNC | StatusFlags::O_DIRECT); + + if new_status_flags.contains(StatusFlags::O_NONBLOCK) { + self.inner.lock().unwrap().set_non_blocking(); + } else { + self.inner.lock().unwrap().set_blocking(); + } Ok(()) } + fn poll(&self) -> Result { + let ringbuf_reader = self.inner.lock().unwrap(); + let mut events = PollEventFlags::empty(); + if ringbuf_reader.can_read() { + events |= PollEventFlags::POLLIN | PollEventFlags::POLLRDNORM; + } + + if ringbuf_reader.is_peer_closed() { + events |= PollEventFlags::POLLHUP; + } + + Ok(events) + } + + fn enqueue_event(&self, event: IoEvent) -> Result<()> { + let ringbuf_reader = self.inner.lock().unwrap(); + ringbuf_reader.enqueue_event(event) + } + + fn dequeue_event(&self) -> Result<()> { + let ringbuf_reader = self.inner.lock().unwrap(); + ringbuf_reader.dequeue_event() + } + fn as_any(&self) -> &dyn Any { self } @@ -92,7 +100,6 @@ impl File for PipeReader { unsafe impl Send for PipeReader {} unsafe impl Sync for PipeReader {} -#[derive(Debug)] pub struct PipeWriter { inner: SgxMutex, status_flags: RwLock, @@ -100,32 +107,13 @@ pub struct PipeWriter { impl File for PipeWriter { fn write(&self, buf: &[u8]) -> Result { - let ringbuf = self.inner.lock().unwrap(); - ringbuf.write(buf) + let mut ringbuf = self.inner.lock().unwrap(); + ringbuf.write_to_buffer(buf) } fn writev(&self, bufs: &[&[u8]]) -> Result { - let ringbuf = self.inner.lock().unwrap(); - let mut total_bytes = 0; - for buf in bufs { - match ringbuf.write(buf) { - Ok(this_len) => { - total_bytes += this_len; - if this_len < buf.len() { - break; - } - } - Err(e) => { - match total_bytes { - // a complete failure - 0 => return Err(e), - // a partially failure - _ => break, - } - } - } - } - Ok(total_bytes) + let mut ringbuf = self.inner.lock().unwrap(); + ringbuf.write_to_vector(bufs) } fn seek(&self, pos: SeekFrom) -> Result { @@ -146,14 +134,59 @@ impl File for PipeWriter { // Only O_NONBLOCK, O_ASYNC and O_DIRECT can be set *status_flags = new_status_flags & (StatusFlags::O_NONBLOCK | StatusFlags::O_ASYNC | StatusFlags::O_DIRECT); + + if new_status_flags.contains(StatusFlags::O_NONBLOCK) { + self.inner.lock().unwrap().set_non_blocking(); + } else { + self.inner.lock().unwrap().set_blocking(); + } Ok(()) } + fn poll(&self) -> Result { + let ringbuf_writer = self.inner.lock().unwrap(); + let mut events = PollEventFlags::empty(); + if ringbuf_writer.can_write() { + events |= PollEventFlags::POLLOUT | PollEventFlags::POLLWRNORM; + } + if ringbuf_writer.is_peer_closed() { + events |= PollEventFlags::POLLERR; + } + + Ok(events) + } + + fn enqueue_event(&self, event: IoEvent) -> Result<()> { + let ringbuf_writer = self.inner.lock().unwrap(); + ringbuf_writer.enqueue_event(event) + } + + fn dequeue_event(&self) -> Result<()> { + let ringbuf_writer = self.inner.lock().unwrap(); + ringbuf_writer.dequeue_event() + } + fn as_any(&self) -> &dyn Any { self } } +impl fmt::Debug for PipeReader { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("PipeReader") + .field("status_flags", &self.status_flags) + .finish() + } +} + +impl fmt::Debug for PipeWriter { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("PipeWriter") + .field("status_flags", &self.status_flags) + .finish() + } +} + unsafe impl Send for PipeWriter {} unsafe impl Sync for PipeWriter {} @@ -162,11 +195,29 @@ pub fn do_pipe2(flags: u32) -> Result<[FileDesc; 2]> { let status_flags = StatusFlags::from_bits_truncate(flags); debug!("pipe2: flags: {:?} {:?}", creation_flags, status_flags); - let current = current!(); - let pipe = Pipe::new(status_flags)?; + let (pipe_reader, pipe_writer) = pipe(status_flags)?; let close_on_spawn = creation_flags.must_close_on_spawn(); - let reader_fd = current.add_file(Arc::new(Box::new(pipe.reader)), close_on_spawn); - let writer_fd = current.add_file(Arc::new(Box::new(pipe.writer)), close_on_spawn); + + let current = current!(); + let reader_fd = current.add_file(Arc::new(Box::new(pipe_reader)), close_on_spawn); + let writer_fd = current.add_file(Arc::new(Box::new(pipe_writer)), close_on_spawn); trace!("pipe2: reader_fd: {}, writer_fd: {}", reader_fd, writer_fd); Ok([reader_fd, writer_fd]) } + +pub trait PipeType { + fn as_pipe_reader(&self) -> Result<&PipeReader>; + fn as_pipe_writer(&self) -> Result<&PipeWriter>; +} +impl PipeType for FileRef { + fn as_pipe_reader(&self) -> Result<&PipeReader> { + self.as_any() + .downcast_ref::() + .ok_or_else(|| errno!(EBADF, "not a pipe reader")) + } + fn as_pipe_writer(&self) -> Result<&PipeWriter> { + self.as_any() + .downcast_ref::() + .ok_or_else(|| errno!(EBADF, "not a pipe writer")) + } +} diff --git a/src/libos/src/lib.rs b/src/libos/src/lib.rs index 797c8354..3a1f3eab 100644 --- a/src/libos/src/lib.rs +++ b/src/libos/src/lib.rs @@ -15,6 +15,7 @@ #![feature(negative_impls)] // for may_dangle in rw_lock #![feature(dropck_eyepatch)] +#![feature(option_expect_none)] #[macro_use] extern crate alloc; @@ -40,6 +41,7 @@ extern crate rcore_fs_ramfs; extern crate rcore_fs_sefs; #[macro_use] extern crate derive_builder; +extern crate ringbuf; extern crate serde; extern crate serde_json; diff --git a/src/libos/src/net/io_multiplexing/io_event.rs b/src/libos/src/net/io_multiplexing/io_event.rs new file mode 100644 index 00000000..11511452 --- /dev/null +++ b/src/libos/src/net/io_multiplexing/io_event.rs @@ -0,0 +1,54 @@ +use super::*; +use crate::fs::EventFile; + +lazy_static! { + pub static ref THREAD_NOTIFIERS: SgxMutex> = + SgxMutex::new(HashMap::new()); +} + +#[derive(Debug)] +pub enum IoEvent { + Poll(PollEvent), + Epoll(EpollEvent), + BlockingRead, + BlockingWrite, +} + +pub fn notify_thread(tid: pid_t) -> Result<()> { + debug!("notify thread {}", tid); + assert_ne!( + tid, + current!().tid(), + "a waiting thread cannot run other programs" + ); + let data: &[u8] = &[1, 0, 0, 0, 0, 0, 0, 0]; + + THREAD_NOTIFIERS + .lock() + .unwrap() + .get(&tid) + .unwrap() + .write(&data)?; + Ok(()) +} + +pub fn clear_notifier_status(tid: pid_t) -> Result<()> { + // One can only clear self for now + assert_eq!(tid, current!().tid()); + debug!("clear thread {} notifier", tid); + let mut data: &mut [u8] = &mut [0; 8]; + + // Ignore the error for no data to read + THREAD_NOTIFIERS + .lock() + .unwrap() + .get(&tid) + .unwrap() + .read(&mut data); + Ok(()) +} + +pub fn wait_for_notification() -> Result<()> { + do_poll(&mut vec![], std::ptr::null_mut())?; + Ok(()) +} diff --git a/src/libos/src/net/io_multiplexing/mod.rs b/src/libos/src/net/io_multiplexing/mod.rs index df7e7fcd..3cc87ea0 100644 --- a/src/libos/src/net/io_multiplexing/mod.rs +++ b/src/libos/src/net/io_multiplexing/mod.rs @@ -1,14 +1,18 @@ use super::*; mod epoll; +mod io_event; mod poll; mod select; pub use self::epoll::{AsEpollFile, EpollCtlCmd, EpollEvent, EpollEventFlags, EpollFile}; -pub use self::poll::do_poll; +pub use self::io_event::{ + clear_notifier_status, notify_thread, wait_for_notification, IoEvent, THREAD_NOTIFIERS, +}; +pub use self::poll::{do_poll, PollEvent, PollEventFlags}; pub use self::select::{select, FdSetExt}; -use fs::{AsDevRandom, AsEvent, CreationFlags, File, FileDesc, FileRef}; +use fs::{AsDevRandom, AsEvent, CreationFlags, File, FileDesc, FileRef, PipeType}; use std::any::Any; use std::convert::TryFrom; use std::fmt; diff --git a/src/libos/src/net/io_multiplexing/poll.rs b/src/libos/src/net/io_multiplexing/poll.rs index 4cbf6c30..c4b85d18 100644 --- a/src/libos/src/net/io_multiplexing/poll.rs +++ b/src/libos/src/net/io_multiplexing/poll.rs @@ -1,76 +1,223 @@ use super::*; -pub fn do_poll(pollfds: &mut [libc::pollfd], timeout: c_int) -> Result { - debug!( - "poll: {:?}, timeout: {}", - pollfds.iter().map(|p| p.fd).collect::>(), - timeout - ); +bitflags! { + #[derive(Default)] + #[repr(C)] + pub struct PollEventFlags: i16 { + const POLLIN = 0x0001; + const POLLPRI = 0x0002; + const POLLOUT = 0x0004; + const POLLERR = 0x0008; + const POLLHUP = 0x0010; + const POLLNVAL = 0x0020; + const POLLRDNORM = 0x0040; + const POLLRDBAND = 0x0080; + const POLLWRNORM = 0x0100; + const POLLWRBAND = 0x0200; + const POLLMSG = 0x0400; + const POLLRDHUP = 0x2000; + } +} - // Untrusted pollfd's that will be modified by OCall - let mut u_pollfds: Vec = pollfds.to_vec(); +#[derive(Clone, Copy, Debug)] +#[repr(C)] +pub struct PollEvent { + fd: FileDesc, + events: PollEventFlags, + revents: PollEventFlags, +} +impl PollEvent { + pub fn new(fd: FileDesc, events: PollEventFlags) -> Self { + let revents = PollEventFlags::empty(); + Self { + fd, + events, + revents, + } + } + + pub fn fd(&self) -> FileDesc { + self.fd + } + + pub fn events(&self) -> PollEventFlags { + self.events + } + + pub fn revents(&self) -> PollEventFlags { + self.revents + } + + pub fn set_events(&mut self, events: PollEventFlags) { + self.events = events; + } + + pub fn get_revents(&mut self, events: PollEventFlags) -> bool { + self.revents = (self.events + | PollEventFlags::POLLHUP + | PollEventFlags::POLLERR + | PollEventFlags::POLLNVAL) + & events; + !self.revents.is_empty() + } +} + +pub fn do_poll(pollfds: &mut [PollEvent], timeout: *mut timeval_t) -> Result { + let mut libos_ready_num = 0; + let mut host_ready_num = 0; + let mut notified = 0; let current = current!(); + + // The pollfd of the host file + let mut host_pollfds: Vec = Vec::new(); + // The indices in pollfds of host file + let mut index_host_pollfds: Vec = Vec::new(); + // Vec: The indices in pollfds which may be more than one for the same file + // PollEvent: the merged pollfd of FileDesc + let mut libos_pollfds: HashMap)> = HashMap::new(); + for (i, pollfd) in pollfds.iter_mut().enumerate() { - // Poll should just ignore negative fds - if pollfd.fd < 0 { - u_pollfds[i].fd = -1; - u_pollfds[i].revents = 0; + // Ignore negative fds + if (pollfd.fd() as i32) < 0 { continue; } - let file_ref = current.file(pollfd.fd as FileDesc)?; - if let Ok(socket) = file_ref.as_socket() { - // convert libos fd to host fd in the copy to keep pollfds unchanged - u_pollfds[i].fd = socket.fd(); - u_pollfds[i].revents = 0; - } else if let Ok(eventfd) = file_ref.as_event() { - u_pollfds[i].fd = eventfd.get_host_fd(); - u_pollfds[i].revents = 0; - } else if let Ok(socket) = file_ref.as_unix_socket() { - // FIXME: spin poll until can read (hack for php) - while (pollfd.events & libc::POLLIN) != 0 && socket.poll()?.0 == false { - spin_loop_hint(); + let file_ref = if let Ok(file_ref) = current.file(pollfd.fd) { + file_ref + } else { + pollfd.get_revents(PollEventFlags::POLLNVAL); + continue; + }; + + if file_ref.as_unix_socket().is_ok() + || file_ref.as_pipe_reader().is_ok() + || file_ref.as_pipe_writer().is_ok() + || file_ref.as_dev_random().is_ok() + { + let events = file_ref.poll()?; + debug!("polled events are {:?}", events); + if pollfd.get_revents(events) { + libos_ready_num += 1; } - let (r, w, e) = socket.poll()?; - if r { - pollfd.revents |= libc::POLLIN; + // Merge pollfds with the same fd + if let Some((old_pollfd, index_vec)) = + libos_pollfds.insert(pollfd.fd(), (*pollfd, vec![i])) + { + let (new_pollfd, new_index_vec) = libos_pollfds.get_mut(&pollfd.fd()).unwrap(); + new_pollfd.set_events(old_pollfd.events() | new_pollfd.events()); + new_index_vec.extend_from_slice(&index_vec); } - if w { - pollfd.revents |= libc::POLLOUT; - } - pollfd.revents &= pollfd.events; - if e { - pollfd.revents |= libc::POLLERR; - } - warn!("poll unix socket is unimplemented, spin for read"); - return Ok(1); - } else if let Ok(dev_random) = file_ref.as_dev_random() { - return Ok(dev_random.poll(pollfd)?); + continue; + } + + if let Ok(socket) = file_ref.as_socket() { + let fd = socket.fd() as FileDesc; + index_host_pollfds.push(i); + host_pollfds.push(PollEvent::new(fd, pollfd.events())); + } else if let Ok(eventfd) = file_ref.as_event() { + let fd = eventfd.get_host_fd() as FileDesc; + index_host_pollfds.push(i); + host_pollfds.push(PollEvent::new(fd, pollfd.events())); } else { return_errno!(EBADF, "not a supported file type"); } } - let (u_pollfds_ptr, u_pollfds_len) = u_pollfds.as_mut_slice().as_mut_ptr_and_len(); + let notifier_host_fd = THREAD_NOTIFIERS + .lock() + .unwrap() + .get(¤t.tid()) + .unwrap() + .get_host_fd(); - let num_events = try_libc!(libc::ocall::poll( - u_pollfds_ptr, - u_pollfds_len as u64, - timeout - )) as usize; - assert!(num_events <= pollfds.len()); + debug!( + "number of ready libos fd is {}; notifier_host_fd is {}", + libos_ready_num, notifier_host_fd + ); - // Copy back revents from the untrusted pollfds - let mut num_nonzero_revents = 0; - for (i, pollfd) in pollfds.iter_mut().enumerate() { - if u_pollfds[i].revents == 0 { - continue; + let ret = if libos_ready_num != 0 { + // Clear the status of notifier before wait + clear_notifier_status(current!().tid())?; + + let mut zero_timeout: timeval_t = timeval_t::new(0, 0); + + do_poll_in_host(&mut host_pollfds, &mut zero_timeout, notifier_host_fd)? + } else { + host_pollfds.push(PollEvent::new( + notifier_host_fd as u32, + PollEventFlags::POLLIN, + )); + // Clear the status of notifier before queue + clear_notifier_status(current!().tid())?; + + for (fd, (pollfd, _)) in &libos_pollfds { + let file_ref = current.file(*fd)?; + file_ref.enqueue_event(IoEvent::Poll(*pollfd))?; + } + let ret = do_poll_in_host(&mut host_pollfds, timeout, notifier_host_fd)?; + // Pop the notifier first + if !host_pollfds.pop().unwrap().revents().is_empty() { + notified = 1; + } + // Set the return events and dequeue + for (fd, (pollfd, index_vec)) in &libos_pollfds { + let file_ref = current.file(*fd)?; + let events = file_ref.poll()?; + for i in index_vec { + if pollfds[*i].get_revents(events) { + libos_ready_num += 1; + } + } + file_ref.dequeue_event()?; + } + ret + }; + + // Copy back revents for host pollfd + for (i, pollfd) in host_pollfds.iter().enumerate() { + if pollfds[index_host_pollfds[i]].get_revents(pollfd.revents()) { + host_ready_num += 1; } - pollfd.revents = u_pollfds[i].revents; - num_nonzero_revents += 1; } - assert!(num_nonzero_revents == num_events); - Ok(num_events as usize) + + assert!(ret == host_ready_num + notified); + debug!("pollfds returns {:?}", pollfds); + Ok(host_ready_num + libos_ready_num) +} + +fn do_poll_in_host( + mut host_pollfds: &mut [PollEvent], + timeout: *mut timeval_t, + notifier_host_fd: c_int, +) -> Result { + let (host_pollfds_ptr, host_pollfds_len) = host_pollfds.as_mut_ptr_and_len(); + + let ret = try_libc!({ + let mut retval: c_int = 0; + let status = occlum_ocall_poll( + &mut retval, + host_pollfds_ptr as *mut _, + host_pollfds_len as u64, + timeout, + notifier_host_fd, + ); + assert!(status == sgx_status_t::SGX_SUCCESS); + + retval + }) as usize; + + assert!(ret <= host_pollfds.len()); + Ok(ret) +} + +extern "C" { + fn occlum_ocall_poll( + ret: *mut c_int, + fds: *mut PollEvent, + nfds: u64, + timeout: *mut timeval_t, + eventfd: c_int, + ) -> sgx_status_t; } diff --git a/src/libos/src/net/io_multiplexing/select.rs b/src/libos/src/net/io_multiplexing/select.rs index c7920681..ba8fe14d 100644 --- a/src/libos/src/net/io_multiplexing/select.rs +++ b/src/libos/src/net/io_multiplexing/select.rs @@ -5,18 +5,17 @@ pub fn select( readfds: &mut libc::fd_set, writefds: &mut libc::fd_set, exceptfds: &mut libc::fd_set, - timeout: Option<&mut timeval_t>, + timeout: *mut timeval_t, ) -> Result { - debug!("select: nfds: {} timeout: {:?}", nfds, timeout); + debug!( + "read: {} write: {} exception: {}", + readfds.format(), + writefds.format(), + exceptfds.format() + ); - let current = current!(); - let file_table = current.files().lock().unwrap(); - - let mut max_host_fd = None; - let mut host_to_libos_fd = [None; libc::FD_SETSIZE]; - let mut unsafe_readfds = libc::fd_set::new_empty(); - let mut unsafe_writefds = libc::fd_set::new_empty(); - let mut unsafe_exceptfds = libc::fd_set::new_empty(); + let mut ready_num = 0; + let mut pollfds: Vec = Vec::new(); for fd in 0..(nfds as FileDesc) { let (r, w, e) = ( @@ -28,165 +27,63 @@ pub fn select( continue; } - let fd_ref = file_table.get(fd)?; - - if let Ok(socket) = fd_ref.as_unix_socket() { - warn!("select unix socket is unimplemented, spin for read"); - readfds.clear(); - writefds.clear(); - exceptfds.clear(); - - // FIXME: spin poll until can read (hack for php) - while r && socket.poll()?.0 == false { - spin_loop_hint(); - } - - let (rr, ww, ee) = socket.poll()?; - let mut ready_num = 0; - if r && rr { - readfds.set(fd)?; - ready_num += 1; - } - if w && ww { - writefds.set(fd)?; - ready_num += 1; - } - if e && ee { - exceptfds.set(fd)?; - ready_num += 1; - } - return Ok(ready_num); + if current!().file(fd).is_err() { + return_errno!( + EBADF, + "An invalid file descriptor was given in one of the sets" + ); } - let host_fd = if let Ok(socket) = fd_ref.as_socket() { - socket.fd() - } else if let Ok(eventfd) = fd_ref.as_event() { - eventfd.get_host_fd() - } else { - return_errno!(EBADF, "unsupported file type"); - } as FileDesc; - - if host_fd as usize >= libc::FD_SETSIZE { - return_errno!(EBADF, "host fd exceeds FD_SETSIZE"); - } - - // convert libos fd to host fd - host_to_libos_fd[host_fd as usize] = Some(fd); - max_host_fd = Some(max(max_host_fd.unwrap_or(0), host_fd as c_int)); + let mut events = PollEventFlags::empty(); if r { - unsafe_readfds.set(host_fd)?; + events |= PollEventFlags::POLLIN; } if w { - unsafe_writefds.set(host_fd)?; + events |= PollEventFlags::POLLOUT; } if e { - unsafe_exceptfds.set(host_fd)?; + events |= PollEventFlags::POLLPRI; } + + pollfds.push(PollEvent::new(fd, events)); } - // Unlock the file table as early as possible - drop(file_table); - - let host_nfds = if let Some(fd) = max_host_fd { - fd + 1 + let mut origin_timeout: timeval_t = if timeout.is_null() { + Default::default() } else { - // Set nfds to zero if no fd is monitored - 0 + unsafe { *timeout } }; - let ret = do_select_in_host( - host_nfds, - &mut unsafe_readfds, - &mut unsafe_writefds, - &mut unsafe_exceptfds, - timeout, - )?; + let ret = do_poll(&mut pollfds, timeout)?; - // convert fd back and write fdset and do ocall check - let mut ready_num = 0; - for host_fd in 0..host_nfds as FileDesc { - let fd_option = host_to_libos_fd[host_fd as usize]; - let (r, w, e) = ( - unsafe_readfds.is_set(host_fd), - unsafe_writefds.is_set(host_fd), - unsafe_exceptfds.is_set(host_fd), - ); - if !(r || w || e) { - if let Some(fd) = fd_option { - readfds.unset(fd)?; - writefds.unset(fd)?; - exceptfds.unset(fd)?; - } - continue; - } + readfds.clear(); + writefds.clear(); + exceptfds.clear(); - let fd = fd_option.expect("host_fd with events must have a responding libos fd"); - - if r { - assert!(readfds.is_set(fd)); - ready_num += 1; - } else { - readfds.unset(fd)?; - } - if w { - assert!(writefds.is_set(fd)); - ready_num += 1; - } else { - writefds.unset(fd)?; - } - if e { - assert!(exceptfds.is_set(fd)); - ready_num += 1; - } else { - exceptfds.unset(fd)?; - } - } - - assert!(ready_num == ret); - Ok(ret) -} - -fn do_select_in_host( - host_nfds: c_int, - readfds: &mut libc::fd_set, - writefds: &mut libc::fd_set, - exceptfds: &mut libc::fd_set, - timeout: Option<&mut timeval_t>, -) -> Result { - let readfds_ptr = readfds.as_raw_ptr_mut(); - let writefds_ptr = writefds.as_raw_ptr_mut(); - let exceptfds_ptr = exceptfds.as_raw_ptr_mut(); - - let mut origin_timeout: timeval_t = Default::default(); - let timeout_ptr = if let Some(to) = timeout { - origin_timeout = *to; - to - } else { - std::ptr::null_mut() - } as *mut timeval_t; - - let ret = try_libc!({ - let mut retval: c_int = 0; - let status = occlum_ocall_select( - &mut retval, - host_nfds, - readfds_ptr, - writefds_ptr, - exceptfds_ptr, - timeout_ptr, - ); - assert!(status == sgx_status_t::SGX_SUCCESS); - - retval - }) as isize; - - if !timeout_ptr.is_null() { - let time_left = unsafe { *(timeout_ptr) }; + if !timeout.is_null() { + let time_left = unsafe { *(timeout) }; time_left.validate()?; assert!(time_left.as_duration() <= origin_timeout.as_duration()); } - Ok(ret) + debug!("returned pollfds are {:?}", pollfds); + for pollfd in &pollfds { + let (r_poll, w_poll, e_poll) = convert_to_readable_writable_exceptional(pollfd.revents()); + if r_poll { + readfds.set(pollfd.fd())?; + ready_num += 1; + } + if w_poll { + writefds.set(pollfd.fd())?; + ready_num += 1; + } + if e_poll { + exceptfds.set(pollfd.fd())?; + ready_num += 1; + } + } + + Ok(ready_num) } /// Safe methods for `libc::fd_set` @@ -198,6 +95,7 @@ pub trait FdSetExt { fn clear(&mut self); fn is_empty(&self) -> bool; fn as_raw_ptr_mut(&mut self) -> *mut Self; + fn format(&self) -> String; } impl FdSetExt for libc::fd_set { @@ -252,15 +150,29 @@ impl FdSetExt for libc::fd_set { self as *mut libc::fd_set } } + + fn format(&self) -> String { + let set = unsafe { + std::slice::from_raw_parts(self as *const Self as *const u64, libc::FD_SETSIZE / 64) + }; + format!("libc::fd_set: {:x?}", set) + } } -extern "C" { - fn occlum_ocall_select( - ret: *mut c_int, - nfds: c_int, - readfds: *mut libc::fd_set, - writefds: *mut libc::fd_set, - exceptfds: *mut libc::fd_set, - timeout: *mut timeval_t, - ) -> sgx_status_t; +// The correspondence is from man2/select.2.html +fn convert_to_readable_writable_exceptional(events: PollEventFlags) -> (bool, bool, bool) { + ( + (PollEventFlags::POLLRDNORM + | PollEventFlags::POLLRDBAND + | PollEventFlags::POLLIN + | PollEventFlags::POLLHUP + | PollEventFlags::POLLERR) + .intersects(events), + (PollEventFlags::POLLWRBAND + | PollEventFlags::POLLWRNORM + | PollEventFlags::POLLOUT + | PollEventFlags::POLLERR) + .intersects(events), + PollEventFlags::POLLPRI.intersects(events), + ) } diff --git a/src/libos/src/net/mod.rs b/src/libos/src/net/mod.rs index 71f9270e..56a84726 100644 --- a/src/libos/src/net/mod.rs +++ b/src/libos/src/net/mod.rs @@ -10,7 +10,10 @@ mod socket_file; mod syscalls; mod unix_socket; -pub use self::io_multiplexing::EpollEvent; +pub use self::io_multiplexing::{ + clear_notifier_status, notify_thread, wait_for_notification, EpollEvent, IoEvent, PollEvent, + PollEventFlags, THREAD_NOTIFIERS, +}; pub use self::iovs::{Iovs, IovsMut, SliceAsLibcIovec}; pub use self::msg::{msghdr, msghdr_mut, MsgHdr, MsgHdrMut}; pub use self::msg_flags::{MsgHdrFlags, RecvFlags, SendFlags}; diff --git a/src/libos/src/net/syscalls.rs b/src/libos/src/net/syscalls.rs index 11786739..fa91e513 100644 --- a/src/libos/src/net/syscalls.rs +++ b/src/libos/src/net/syscalls.rs @@ -507,6 +507,13 @@ pub fn do_select( ); } + if !timeout.is_null() { + from_user::check_ptr(timeout)?; + unsafe { + (*timeout).validate()?; + } + } + // Select handles empty set and null in the same way // TODO: Elegently handle the empty fd_set without allocating redundant fd_set let mut empty_set_for_read = libc::fd_set::new_empty(); @@ -532,21 +539,11 @@ pub fn do_select( &mut empty_set_for_except }; - let timeout_option = if !timeout.is_null() { - from_user::check_ptr(timeout)?; - unsafe { - (*timeout).validate()?; - Some(&mut *timeout) - } - } else { - None - }; - - let ret = io_multiplexing::select(nfds, readfds, writefds, exceptfds, timeout_option)?; + let ret = io_multiplexing::select(nfds, readfds, writefds, exceptfds, timeout)?; Ok(ret) } -pub fn do_poll(fds: *mut libc::pollfd, nfds: libc::nfds_t, timeout: c_int) -> Result { +pub fn do_poll(fds: *mut PollEvent, nfds: libc::nfds_t, timeout: c_int) -> Result { // It behaves like sleep when fds is null and nfds is zero. if !fds.is_null() || nfds != 0 { from_user::check_mut_array(fds, nfds as usize)?; @@ -564,8 +561,19 @@ pub fn do_poll(fds: *mut libc::pollfd, nfds: libc::nfds_t, timeout: c_int) -> Re } let polls = unsafe { std::slice::from_raw_parts_mut(fds, nfds as usize) }; + debug!("poll: {:?}, timeout: {}", polls, timeout); - let n = io_multiplexing::do_poll(polls, timeout)?; + let mut time_val = timeval_t::new( + ((timeout as u32) / 1000) as i64, + ((timeout as u32) % 1000 * 1000) as i64, + ); + let tmp_to = if timeout == -1 { + std::ptr::null_mut() + } else { + &mut time_val + }; + + let n = io_multiplexing::do_poll(polls, tmp_to)?; Ok(n as isize) } diff --git a/src/libos/src/net/unix_socket.rs b/src/libos/src/net/unix_socket.rs index 8d3e8a8d..6838dddb 100644 --- a/src/libos/src/net/unix_socket.rs +++ b/src/libos/src/net/unix_socket.rs @@ -6,12 +6,13 @@ use std::collections::btree_map::BTreeMap; use std::fmt; use std::sync::atomic::{spin_loop_hint, AtomicUsize, Ordering}; use std::sync::SgxMutex as Mutex; -use util::ring_buf::{RingBuf, RingBufReader, RingBufWriter}; +use util::ring_buf::{ring_buffer, RingBufReader, RingBufWriter}; pub struct UnixSocketFile { inner: Mutex, } +// TODO: add enqueue_event and dequeue_event impl File for UnixSocketFile { fn read(&self, buf: &mut [u8]) -> Result { let mut inner = self.inner.lock().unwrap(); @@ -33,32 +34,12 @@ impl File for UnixSocketFile { fn readv(&self, bufs: &mut [&mut [u8]]) -> Result { let mut inner = self.inner.lock().unwrap(); - let mut total_len = 0; - for buf in bufs { - match inner.read(buf) { - Ok(len) => { - total_len += len; - } - Err(_) if total_len != 0 => break, - Err(e) => return Err(e.into()), - } - } - Ok(total_len) + inner.readv(bufs) } fn writev(&self, bufs: &[&[u8]]) -> Result { let mut inner = self.inner.lock().unwrap(); - let mut total_len = 0; - for buf in bufs { - match inner.write(buf) { - Ok(len) => { - total_len += len; - } - Err(_) if total_len != 0 => break, - Err(e) => return Err(e.into()), - } - } - Ok(total_len) + inner.writev(bufs) } fn metadata(&self) -> Result { @@ -85,6 +66,11 @@ impl File for UnixSocketFile { inner.ioctl(cmd) } + fn poll(&self) -> Result { + let mut inner = self.inner.lock().unwrap(); + inner.poll() + } + fn as_any(&self) -> &dyn Any { self } @@ -124,11 +110,6 @@ impl UnixSocketFile { inner.connect(path) } - pub fn poll(&self) -> Result<(bool, bool, bool)> { - let mut inner = self.inner.lock().unwrap(); - inner.poll() - } - pub fn socketpair(socket_type: i32, protocol: i32) -> Result<(Self, Self)> { let listen_socket = Self::new(socket_type, protocol)?; let bound_path = listen_socket.bind_until_success(); @@ -252,20 +233,48 @@ impl UnixSocket { Ok(()) } - pub fn read(&self, buf: &mut [u8]) -> Result { - self.channel()?.reader.read(buf) + pub fn read(&mut self, buf: &mut [u8]) -> Result { + self.channel_mut()?.reader.read_from_buffer(buf) } - pub fn write(&self, buf: &[u8]) -> Result { - self.channel()?.writer.write(buf) + pub fn readv(&mut self, bufs: &mut [&mut [u8]]) -> Result { + self.channel_mut()?.reader.read_from_vector(bufs) } - pub fn poll(&self) -> Result<(bool, bool, bool)> { - // (read, write, error) - let channel = self.channel()?; - let r = channel.reader.can_read(); - let w = channel.writer.can_write(); - Ok((r, w, false)) + pub fn write(&mut self, buf: &[u8]) -> Result { + self.channel_mut()?.writer.write_to_buffer(buf) + } + + pub fn writev(&mut self, bufs: &[&[u8]]) -> Result { + self.channel_mut()?.writer.write_to_vector(bufs) + } + + fn poll(&self) -> Result { + let channel_result = self.channel(); + if let Ok(channel) = channel_result { + let readable = channel.reader.can_read() && !channel.reader.is_peer_closed(); + let writable = channel.writer.can_write() && !channel.writer.is_peer_closed(); + let events = if readable ^ writable { + if channel.reader.can_read() { + PollEventFlags::POLLRDHUP | PollEventFlags::POLLIN | PollEventFlags::POLLRDNORM + } else { + PollEventFlags::POLLRDHUP + } + // both readable and writable + } else if readable { + PollEventFlags::POLLIN + | PollEventFlags::POLLOUT + | PollEventFlags::POLLRDNORM + | PollEventFlags::POLLWRNORM + } else { + PollEventFlags::POLLHUP + }; + Ok(events) + } else { + // For the unconnected socket + // TODO: add write support for unconnected sockets like linux does + Ok(PollEventFlags::POLLHUP) + } } pub fn ioctl(&self, cmd: &mut IoctlCmd) -> Result { @@ -283,6 +292,14 @@ impl UnixSocket { Ok(0) } + fn channel_mut(&mut self) -> Result<&mut Channel> { + if let Status::Connected(ref mut channel) = &mut self.status { + Ok(channel) + } else { + return_errno!(EBADF, "UnixSocket is not connected") + } + } + fn channel(&self) -> Result<&Channel> { if let Status::Connected(channel) = &self.status { Ok(channel) @@ -349,15 +366,15 @@ unsafe impl Sync for Channel {} impl Channel { fn new_pair() -> Result<(Channel, Channel)> { - let buf1 = RingBuf::new(DEFAULT_BUF_SIZE)?; - let buf2 = RingBuf::new(DEFAULT_BUF_SIZE)?; + let (reader1, writer1) = ring_buffer(DEFAULT_BUF_SIZE)?; + let (reader2, writer2) = ring_buffer(DEFAULT_BUF_SIZE)?; let channel1 = Channel { - reader: buf1.reader, - writer: buf2.writer, + reader: reader1, + writer: writer2, }; let channel2 = Channel { - reader: buf2.reader, - writer: buf1.writer, + reader: reader2, + writer: writer1, }; Ok((channel1, channel2)) } diff --git a/src/libos/src/process/thread/mod.rs b/src/libos/src/process/thread/mod.rs index 7f88c7da..f51a2ebf 100644 --- a/src/libos/src/process/thread/mod.rs +++ b/src/libos/src/process/thread/mod.rs @@ -6,6 +6,8 @@ use super::{ FileTableRef, ForcedExitStatus, FsViewRef, ProcessRef, ProcessVM, ProcessVMRef, ResourceLimitsRef, SchedAgentRef, TermStatus, ThreadRef, }; +use crate::fs::{EventCreationFlags, EventFile}; +use crate::net::THREAD_NOTIFIERS; use crate::prelude::*; use crate::signal::{SigQueues, SigSet, SigStack}; use crate::time::ThreadProfiler; @@ -146,6 +148,18 @@ impl Thread { self.sched().lock().unwrap().attach(host_tid); self.inner().start(); + let eventfd = EventFile::new( + 0, + EventCreationFlags::EFD_CLOEXEC | EventCreationFlags::EFD_NONBLOCK, + ) + .unwrap(); + + THREAD_NOTIFIERS + .lock() + .unwrap() + .insert(self.tid(), eventfd) + .expect_none("this thread should not have an eventfd before start"); + #[cfg(feature = "syscall_timing")] self.profiler() .lock() @@ -166,6 +180,12 @@ impl Thread { .stop() .unwrap(); + THREAD_NOTIFIERS + .lock() + .unwrap() + .remove(&self.tid()) + .unwrap(); + self.sched().lock().unwrap().detach(); // Remove this thread from its owner process diff --git a/src/libos/src/syscall/mod.rs b/src/libos/src/syscall/mod.rs index d4de72e9..f6dff843 100644 --- a/src/libos/src/syscall/mod.rs +++ b/src/libos/src/syscall/mod.rs @@ -32,8 +32,8 @@ use crate::net::{ do_accept, do_accept4, do_bind, do_connect, do_epoll_create, do_epoll_create1, do_epoll_ctl, do_epoll_pwait, do_epoll_wait, do_getpeername, do_getsockname, do_getsockopt, do_listen, do_poll, do_recvfrom, do_recvmsg, do_select, do_sendmsg, do_sendto, do_setsockopt, do_shutdown, - do_socket, do_socketpair, msghdr, msghdr_mut, AsSocket, AsUnixSocket, EpollEvent, SocketFile, - UnixSocketFile, + do_socket, do_socketpair, msghdr, msghdr_mut, AsSocket, AsUnixSocket, EpollEvent, PollEvent, + SocketFile, UnixSocketFile, }; use crate::process::{ do_arch_prctl, do_clone, do_exit, do_exit_group, do_futex, do_getegid, do_geteuid, do_getgid, @@ -86,7 +86,7 @@ macro_rules! process_syscall_table_with_callback { (Stat = 4) => do_stat(path: *const i8, stat_buf: *mut Stat), (Fstat = 5) => do_fstat(fd: FileDesc, stat_buf: *mut Stat), (Lstat = 6) => do_lstat(path: *const i8, stat_buf: *mut Stat), - (Poll = 7) => do_poll(fds: *mut libc::pollfd, nfds: libc::nfds_t, timeout: c_int), + (Poll = 7) => do_poll(fds: *mut PollEvent, nfds: libc::nfds_t, timeout: c_int), (Lseek = 8) => do_lseek(fd: FileDesc, offset: off_t, whence: i32), (Mmap = 9) => do_mmap(addr: usize, size: usize, perms: i32, flags: i32, fd: FileDesc, offset: off_t), (Mprotect = 10) => do_mprotect(addr: usize, len: usize, prot: u32), @@ -369,7 +369,7 @@ macro_rules! process_syscall_table_with_callback { (TimerfdGettime = 287) => handle_unsupported(), (Accept4 = 288) => do_accept4(fd: c_int, addr: *mut libc::sockaddr, addr_len: *mut libc::socklen_t, flags: c_int), (Signalfd4 = 289) => handle_unsupported(), - (Eventfd2 = 290) => do_eventfd2(init_val: u32, flaggs: i32), + (Eventfd2 = 290) => do_eventfd2(init_val: u32, flags: i32), (EpollCreate1 = 291) => do_epoll_create1(flags: c_int), (Dup3 = 292) => do_dup3(old_fd: FileDesc, new_fd: FileDesc, flags: u32), (Pipe2 = 293) => do_pipe2(fds_u: *mut i32, flags: u32), diff --git a/src/libos/src/time/mod.rs b/src/libos/src/time/mod.rs index 92e815a9..68a20703 100644 --- a/src/libos/src/time/mod.rs +++ b/src/libos/src/time/mod.rs @@ -30,6 +30,13 @@ pub struct timeval_t { } impl timeval_t { + pub fn new(sec: time_t, usec: suseconds_t) -> Self { + let time = Self { sec, usec }; + + time.validate().unwrap(); + time + } + pub fn validate(&self) -> Result<()> { if self.sec >= 0 && self.usec >= 0 && self.usec < 1_000_000 { Ok(()) diff --git a/src/libos/src/util/ring_buf.rs b/src/libos/src/util/ring_buf.rs index dd7475eb..2e4bbec2 100644 --- a/src/libos/src/util/ring_buf.rs +++ b/src/libos/src/util/ring_buf.rs @@ -1,232 +1,428 @@ use alloc::alloc::{alloc, dealloc, Layout}; +use crate::net::{ + clear_notifier_status, notify_thread, wait_for_notification, IoEvent, PollEventFlags, +}; use std::cmp::{max, min}; use std::ptr; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::Arc; use super::*; +use ringbuf::{Consumer, Producer, RingBuffer}; -#[derive(Debug)] -pub struct RingBuf { - pub reader: RingBufReader, - pub writer: RingBufWriter, +pub fn ring_buffer(capacity: usize) -> Result<(RingBufReader, RingBufWriter)> { + let meta = RingBufMeta::new(); + let buffer = RingBuffer::::new(capacity); + let (producer, consumer) = buffer.split(); + let meta_ref = Arc::new(meta); + + let reader = RingBufReader { + inner: consumer, + buffer: meta_ref.clone(), + }; + let writer = RingBufWriter { + inner: producer, + buffer: meta_ref, + }; + Ok((reader, writer)) } -impl RingBuf { - pub fn new(capacity: usize) -> Result { - let inner = Arc::new(RingBufInner::new(capacity)?); - let reader = RingBufReader { - inner: inner.clone(), - }; - let writer = RingBufWriter { inner: inner }; - Ok(RingBuf { - reader: reader, - writer: writer, - }) +struct RingBufMeta { + lock: Arc>, // lock for the synchronization of reader and writer + reader_closed: AtomicBool, // if reader has been dropped + writer_closed: AtomicBool, // if writer has been dropped + reader_wait_queue: SgxMutex>, + writer_wait_queue: SgxMutex>, + // TODO: support O_ASYNC and O_DIRECT in ringbuffer + blocking_read: AtomicBool, // if the read is blocking + blocking_write: AtomicBool, // if the write is blocking +} + +impl RingBufMeta { + pub fn new() -> RingBufMeta { + Self { + lock: Arc::new(SgxMutex::new(true)), + reader_closed: AtomicBool::new(false), + writer_closed: AtomicBool::new(false), + reader_wait_queue: SgxMutex::new(HashMap::new()), + writer_wait_queue: SgxMutex::new(HashMap::new()), + blocking_read: AtomicBool::new(true), + blocking_write: AtomicBool::new(true), + } + } + + pub fn is_reader_closed(&self) -> bool { + self.reader_closed.load(Ordering::SeqCst) + } + + pub fn close_reader(&self) { + self.reader_closed.store(true, Ordering::SeqCst); + } + + pub fn is_writer_closed(&self) -> bool { + self.writer_closed.load(Ordering::SeqCst) + } + + pub fn close_writer(&self) { + self.writer_closed.store(true, Ordering::SeqCst); + } + + pub fn reader_wait_queue(&self) -> &SgxMutex> { + &self.reader_wait_queue + } + + pub fn writer_wait_queue(&self) -> &SgxMutex> { + &self.writer_wait_queue + } + + pub fn enqueue_reader_event(&self, event: IoEvent) -> Result<()> { + self.reader_wait_queue + .lock() + .unwrap() + .insert(current!().tid(), event); + Ok(()) + } + + pub fn dequeue_reader_event(&self) -> Result<()> { + self.reader_wait_queue + .lock() + .unwrap() + .remove(¤t!().tid()) + .unwrap(); + Ok(()) + } + + pub fn enqueue_writer_event(&self, event: IoEvent) -> Result<()> { + self.writer_wait_queue + .lock() + .unwrap() + .insert(current!().tid(), event); + Ok(()) + } + + pub fn dequeue_writer_event(&self) -> Result<()> { + self.writer_wait_queue + .lock() + .unwrap() + .remove(¤t!().tid()) + .unwrap(); + Ok(()) + } + + pub fn blocking_read(&self) -> bool { + self.blocking_read.load(Ordering::SeqCst) + } + + pub fn set_non_blocking_read(&self) { + self.blocking_read.store(false, Ordering::SeqCst); + } + + pub fn set_blocking_read(&self) { + self.blocking_read.store(true, Ordering::SeqCst); + } + + pub fn blocking_write(&self) -> bool { + self.blocking_write.load(Ordering::SeqCst) + } + + pub fn set_non_blocking_write(&self) { + self.blocking_write.store(false, Ordering::SeqCst); + } + + pub fn set_blocking_write(&self) { + self.blocking_write.store(true, Ordering::SeqCst); } } -#[derive(Debug)] pub struct RingBufReader { - inner: Arc, -} - -#[derive(Debug)] -pub struct RingBufWriter { - inner: Arc, -} - -#[derive(Debug)] -struct RingBufInner { - buf: *mut u8, - capacity: usize, - head: AtomicUsize, // write to head - tail: AtomicUsize, // read from tail - closed: AtomicBool, // if reader has been dropped -} - -const RING_BUF_ALIGN: usize = 16; - -impl RingBufInner { - fn new(capacity: usize) -> Result { - // Capacity should be power of two as capacity - 1 is used as mask - let capacity = max(capacity, RING_BUF_ALIGN).next_power_of_two(); - let buf_layout = Layout::from_size_align(capacity, RING_BUF_ALIGN)?; - let buf_ptr = unsafe { alloc(buf_layout) }; - if buf_ptr.is_null() { - return_errno!(ENOMEM, "no memory for new ring buffers"); - } - - Ok(RingBufInner { - buf: buf_ptr, - capacity: capacity, - head: AtomicUsize::new(0), - tail: AtomicUsize::new(0), - closed: AtomicBool::new(false), - }) - } - - fn get_mask(&self) -> usize { - self.capacity - 1 // Note that capacity is a power of two - } - - fn get_head(&self) -> usize { - self.head.load(Ordering::SeqCst) - } - - fn get_tail(&self) -> usize { - self.tail.load(Ordering::SeqCst) - } - - fn set_head(&self, new_head: usize) { - self.head.store(new_head, Ordering::SeqCst) - } - - fn set_tail(&self, new_tail: usize) { - self.tail.store(new_tail, Ordering::SeqCst) - } - - fn is_closed(&self) -> bool { - self.closed.load(Ordering::SeqCst) - } - - fn close(&self) { - self.closed.store(true, Ordering::SeqCst); - } - - unsafe fn read_at(&self, pos: usize, dst_buf: &mut [u8]) { - let dst_ptr = dst_buf.as_mut_ptr(); - let dst_len = dst_buf.len(); - let src_ptr = self.buf.offset(pos as isize); - unsafe { - src_ptr.copy_to_nonoverlapping(dst_ptr, dst_len); - } - } - - unsafe fn write_at(&self, pos: usize, src_buf: &[u8]) { - let src_ptr = src_buf.as_ptr(); - let src_len = src_buf.len(); - let dst_ptr = self.buf.offset(pos as isize); - unsafe { - dst_ptr.copy_from_nonoverlapping(src_ptr, src_len); - } - } -} - -impl Drop for RingBufInner { - fn drop(&mut self) { - let buf_layout = Layout::from_size_align(self.capacity, RING_BUF_ALIGN).unwrap(); - unsafe { - dealloc(self.buf, buf_layout); - } - } + inner: Consumer, + buffer: Arc, } impl RingBufReader { - pub fn read(&self, buf: &mut [u8]) -> Result { - let mut tail = self.inner.get_tail(); - let mut buf_remain = buf.len(); - let mut buf_pos = 0; - while buf_remain > 0 { - let head = self.inner.get_head(); - - let read_nbytes = { - let may_read_nbytes = if tail <= head { - head - tail - } else { - self.inner.capacity - tail - }; - if may_read_nbytes == 0 { - break; - } - - min(may_read_nbytes, buf_remain) - }; - - let dst_buf = &mut buf[buf_pos..(buf_pos + read_nbytes)]; - unsafe { - self.inner.read_at(tail, dst_buf); - } - - tail = (tail + read_nbytes) & self.inner.get_mask(); - self.inner.set_tail(tail); - - buf_pos += read_nbytes; - buf_remain -= read_nbytes; - } - Ok(buf_pos) - } - pub fn can_read(&self) -> bool { self.bytes_to_read() != 0 } - pub fn bytes_to_read(&self) -> usize { - let tail = self.inner.get_tail(); - let head = self.inner.get_head(); - if tail <= head { - head - tail + pub fn read_from_buffer(&mut self, buffer: &mut [u8]) -> Result { + self.read(Some(buffer), None) + } + + pub fn read_from_vector(&mut self, buffers: &mut [&mut [u8]]) -> Result { + self.read(None, Some(buffers)) + } + + fn read( + &mut self, + buffer: Option<&mut [u8]>, + buffers: Option<&mut [&mut [u8]]>, + ) -> Result { + assert!(buffer.is_some() ^ buffers.is_some()); + // In case of write after can_read is false + let lock_ref = self.buffer.lock.clone(); + let lock_holder = lock_ref.lock(); + + if self.can_read() { + let count = if buffer.is_some() { + self.inner.pop_slice(buffer.unwrap()) + } else { + self.pop_slices(buffers.unwrap()) + }; + assert!(count > 0); + self.read_end(); + Ok(count) } else { - self.inner.capacity - tail + head + if self.is_peer_closed() { + return Ok(0); + } + + if !self.buffer.blocking_read() { + return_errno!(EAGAIN, "No data to read"); + } else { + // Clear the status of notifier before enqueue + clear_notifier_status(current!().tid())?; + self.enqueue_event(IoEvent::BlockingRead)?; + drop(lock_holder); + drop(lock_ref); + let ret = wait_for_notification(); + self.dequeue_event()?; + ret?; + + let lock_ref = self.buffer.lock.clone(); + let lock_holder = lock_ref.lock(); + let count = if buffer.is_some() { + self.inner.pop_slice(buffer.unwrap()) + } else { + self.pop_slices(buffers.unwrap()) + }; + + if count > 0 { + self.read_end()?; + } else { + assert!(self.is_peer_closed()); + } + Ok(count) + } + } + } + + fn pop_slices(&mut self, buffers: &mut [&mut [u8]]) -> usize { + let mut total = 0; + for buf in buffers { + let count = self.inner.pop_slice(buf); + total += count; + if count < buf.len() { + break; + } + } + total + } + + pub fn bytes_to_read(&self) -> usize { + self.inner.len() + } + + fn read_end(&self) -> Result<()> { + for (tid, event) in &*self.buffer.writer_wait_queue().lock().unwrap() { + match event { + IoEvent::Poll(poll_events) => { + if !(poll_events.events() + & (PollEventFlags::POLLOUT | PollEventFlags::POLLWRNORM)) + .is_empty() + { + notify_thread(*tid)?; + } + } + IoEvent::Epoll(epoll_file) => unimplemented!(), + IoEvent::BlockingRead => unreachable!(), + IoEvent::BlockingWrite => notify_thread(*tid)?, + } + } + Ok(()) + } + + pub fn is_peer_closed(&self) -> bool { + self.buffer.is_writer_closed() + } + + pub fn enqueue_event(&self, event: IoEvent) -> Result<()> { + self.buffer.enqueue_reader_event(event) + } + + pub fn dequeue_event(&self) -> Result<()> { + self.buffer.dequeue_reader_event() + } + + pub fn set_non_blocking(&self) { + self.buffer.set_non_blocking_read() + } + + pub fn set_blocking(&self) { + self.buffer.set_blocking_read() + } + + fn before_drop(&self) { + for (tid, event) in &*self.buffer.writer_wait_queue().lock().unwrap() { + match event { + IoEvent::Poll(_) | IoEvent::BlockingWrite => notify_thread(*tid).unwrap(), + IoEvent::Epoll(epoll_file) => unimplemented!(), + IoEvent::BlockingRead => unreachable!(), + } } } } impl Drop for RingBufReader { fn drop(&mut self) { - // So the writer knows when a reader is finished - self.inner.close(); + debug!("reader drop"); + self.buffer.close_reader(); + if self.buffer.blocking_write() { + self.before_drop(); + } } } +pub struct RingBufWriter { + inner: Producer, + buffer: Arc, +} + impl RingBufWriter { - pub fn write(&self, buf: &[u8]) -> Result { - if self.inner.is_closed() { - return_errno!(EPIPE, "Reader has been closed"); + pub fn write_to_buffer(&mut self, buffer: &[u8]) -> Result { + self.write(Some(buffer), None) + } + + pub fn write_to_vector(&mut self, buffers: &[&[u8]]) -> Result { + self.write(None, Some(buffers)) + } + + fn write(&mut self, buffer: Option<&[u8]>, buffers: Option<&[&[u8]]>) -> Result { + assert!(buffer.is_some() ^ buffers.is_some()); + + // TODO: send SIGPIPE to the caller + if self.is_peer_closed() { + return_errno!(EPIPE, "reader side is closed"); } - let mut head = self.inner.get_head(); - let mut buf_remain = buf.len(); - let mut buf_pos = 0; - while buf_remain > 0 { - let tail = self.inner.get_tail(); + // In case of read after can_write is false + let lock_ref = self.buffer.lock.clone(); + let lock_holder = lock_ref.lock(); - let write_nbytes = { - let may_write_nbytes = if tail <= head { - self.inner.capacity - head - } else { - tail - head - 1 - }; - if may_write_nbytes == 0 { - break; - } - - min(may_write_nbytes, buf_remain) + if self.can_write() { + let count = if buffer.is_some() { + self.inner.push_slice(buffer.unwrap()) + } else { + self.push_slices(buffers.unwrap()) }; - - let src_buf = &buf[buf_pos..(buf_pos + write_nbytes)]; - unsafe { - self.inner.write_at(head, src_buf); + assert!(count > 0); + self.write_end(); + Ok(count) + } else { + if !self.buffer.blocking_write() { + return_errno!(EAGAIN, "No space to write"); } - head = (head + write_nbytes) & self.inner.get_mask(); - self.inner.set_head(head); + // Clear the status of notifier before enqueue + clear_notifier_status(current!().tid()); + self.enqueue_event(IoEvent::BlockingWrite)?; + drop(lock_holder); + drop(lock_ref); + let ret = wait_for_notification(); + self.dequeue_event()?; + ret?; - buf_pos += write_nbytes; - buf_remain -= write_nbytes; + let lock_ref = self.buffer.lock.clone(); + let lock_holder = lock_ref.lock(); + let count = if buffer.is_some() { + self.inner.push_slice(buffer.unwrap()) + } else { + self.push_slices(buffers.unwrap()) + }; + + if count > 0 { + self.write_end(); + Ok(count) + } else { + return_errno!(EPIPE, "reader side is closed"); + } } - Ok(buf_pos) + } + + fn write_end(&self) -> Result<()> { + for (tid, event) in &*self.buffer.reader_wait_queue().lock().unwrap() { + match event { + IoEvent::Poll(poll_events) => { + if !(poll_events.events() + & (PollEventFlags::POLLIN | PollEventFlags::POLLRDNORM)) + .is_empty() + { + notify_thread(*tid)?; + } + } + IoEvent::Epoll(epoll_file) => unimplemented!(), + IoEvent::BlockingRead => notify_thread(*tid)?, + IoEvent::BlockingWrite => unreachable!(), + } + } + Ok(()) + } + + fn push_slices(&mut self, buffers: &[&[u8]]) -> usize { + let mut total = 0; + for buf in buffers { + let count = self.inner.push_slice(buf); + total += count; + if count < buf.len() { + break; + } + } + total } pub fn can_write(&self) -> bool { - let tail = self.inner.get_tail(); - let head = self.inner.get_head(); - let may_write_nbytes = if tail <= head { - self.inner.capacity - head - } else { - tail - head - 1 - }; - may_write_nbytes != 0 + !self.inner.is_full() + } + + pub fn is_peer_closed(&self) -> bool { + self.buffer.is_reader_closed() + } + + pub fn enqueue_event(&self, event: IoEvent) -> Result<()> { + self.buffer.enqueue_writer_event(event) + } + + pub fn dequeue_event(&self) -> Result<()> { + self.buffer.dequeue_writer_event() + } + + pub fn set_non_blocking(&self) { + self.buffer.set_non_blocking_write() + } + + pub fn set_blocking(&self) { + self.buffer.set_blocking_write() + } + + fn before_drop(&self) { + for (tid, event) in &*self.buffer.reader_wait_queue().lock().unwrap() { + match event { + IoEvent::Poll(_) | IoEvent::BlockingRead => { + notify_thread(*tid).unwrap(); + } + IoEvent::Epoll(epoll_file) => unimplemented!(), + IoEvent::BlockingWrite => unreachable!(), + } + } + } +} + +impl Drop for RingBufWriter { + fn drop(&mut self) { + debug!("writer drop"); + self.buffer.close_writer(); + if self.buffer.blocking_read() { + self.before_drop(); + } } } diff --git a/src/pal/src/ocalls/net.c b/src/pal/src/ocalls/net.c index badfed87..ba47ada7 100644 --- a/src/pal/src/ocalls/net.c +++ b/src/pal/src/ocalls/net.c @@ -1,6 +1,9 @@ +#include #include #include #include +#include +#include #include #include "ocalls.h" @@ -53,10 +56,35 @@ ssize_t occlum_ocall_recvmsg(int sockfd, return ret; } -int occlum_ocall_select(int nfds, - fd_set *readfds, - fd_set *writefds, - fd_set *exceptfds, - struct timeval *timeout) { - return select(nfds, readfds, writefds, exceptfds, timeout); +int occlum_ocall_poll(struct pollfd *fds, + nfds_t nfds, + struct timeval *timeout, + int efd) { + struct timeval start_tv, end_tv, elapsed_tv; + int real_timeout = (timeout == NULL) ? -1 : + (timeout->tv_sec * 1000 + timeout->tv_usec / 1000); + if (timeout != NULL) { + gettimeofday(&start_tv, NULL); + } + + int ret = poll(fds, nfds, real_timeout); + + if (timeout != NULL) { + gettimeofday(&end_tv, NULL); + timersub(&end_tv, &start_tv, &elapsed_tv); + if timercmp(timeout, &elapsed_tv, >= ) { + timersub(timeout, &elapsed_tv, timeout); + } else { + timeout->tv_sec = 0; + timeout->tv_usec = 0; + } + } + + int saved_errno = errno; + // clear the status of the eventfd + uint64_t u = 0; + read(efd, &u, sizeof(uint64_t)); + // restore the errno of poll + errno = saved_errno; + return ret; } diff --git a/test/pipe/main.c b/test/pipe/main.c index da372042..81815c56 100644 --- a/test/pipe/main.c +++ b/test/pipe/main.c @@ -1,6 +1,8 @@ +#include #include #include #include +#include #include #include #include @@ -69,7 +71,103 @@ int test_create_with_flags() { return 0; } -int test_read_write() { +int test_select_timeout() { + fd_set rfds; + + int pipe_fds[2]; + if (pipe(pipe_fds) < 0) { + THROW_ERROR("failed to create a pipe"); + } + + struct timeval tv = { .tv_sec = 1, .tv_usec = 0 }; + + FD_ZERO(&rfds); + FD_SET(pipe_fds[0], &rfds); + struct timeval tv_start, tv_end; + gettimeofday(&tv_start, NULL); + select(pipe_fds[0] + 1, &rfds, NULL, NULL, &tv); + gettimeofday(&tv_end, NULL); + double total_s = tv_end.tv_sec - tv_start.tv_sec; + if (total_s < 1) { + printf("time consumed is %f\n", + total_s + (double)(tv_end.tv_usec - tv_start.tv_usec) / 1000000); + THROW_ERROR("select timer does not work correctly"); + } + + free_pipe(pipe_fds); + return 0; +} + +int test_poll_timeout() { + // Start the timer + struct timeval tv_start, tv_end; + gettimeofday(&tv_start, NULL); + int fds[2]; + if (pipe(fds) < 0) { + THROW_ERROR("pipe failed"); + } + struct pollfd polls[] = { + { .fd = fds[0], .events = POLLOUT }, + { .fd = fds[1], .events = POLLIN } + }; + + poll(polls, 2, 1000); + // Stop the timer + gettimeofday(&tv_end, NULL); + double total_s = tv_end.tv_sec - tv_start.tv_sec; + if ((int)total_s < 1) { + printf("time consumed is %f\n", + total_s + (double)(tv_end.tv_usec - tv_start.tv_usec) / 1000000); + THROW_ERROR("poll timer does not work correctly"); + } + return 0; +} + +int test_select_no_timeout() { + fd_set wfds; + int ret = 0; + + int pipe_fds[2]; + if (pipe(pipe_fds) < 0) { + THROW_ERROR("failed to create a pipe"); + } + + FD_ZERO(&wfds); + FD_SET(pipe_fds[1], &wfds); + ret = select(pipe_fds[1] + 1, NULL, &wfds, NULL, NULL); + if (ret != 1) { + free_pipe(pipe_fds); + THROW_ERROR("select failed"); + } + + if (FD_ISSET(pipe_fds[1], &wfds) == 0) { + free_pipe(pipe_fds); + THROW_ERROR("bad select return"); + } + + free_pipe(pipe_fds); + return 0; +} + +int test_poll_no_timeout() { + int pipe_fds[2]; + if (pipe(pipe_fds) < 0) { + THROW_ERROR("failed to create a pipe"); + } + struct pollfd polls[] = { + { .fd = pipe_fds[0], .events = POLLIN }, + { .fd = pipe_fds[1], .events = POLLOUT }, + { .fd = pipe_fds[1], .events = POLLOUT }, + }; + int ret = poll(polls, 3, -1); + if (ret < 0) { THROW_ERROR("poll error"); } + + if (polls[0].revents != 0 || (polls[1].revents & POLLOUT) == 0 || + (polls[2].revents & POLLOUT) == 0 || ret != 2) { THROW_ERROR("wrong return events"); } + return 0; +} + +int test_select_read_write() { int pipe_fds[2]; if (pipe(pipe_fds) < 0) { THROW_ERROR("failed to create a pipe"); @@ -96,10 +194,19 @@ int test_read_write() { const char *expected_str = msg; size_t expected_len = strlen(expected_str); char actual_str[32] = {0}; - ssize_t actual_len; - do { - actual_len = read(pipe_rd_fd, actual_str, sizeof(actual_str) - 1); - } while (actual_len == 0); + fd_set rfds; + + FD_ZERO(&rfds); + FD_SET(pipe_fds[0], &rfds); + if (select(pipe_fds[0] + 1, &rfds, NULL, NULL, NULL) != 1) { + free_pipe(pipe_fds); + THROW_ERROR("select failed"); + } + + if (read(pipe_rd_fd, actual_str, sizeof(actual_str) - 1) < 0) { + THROW_ERROR("reading pipe failed"); + }; + if (strncmp(expected_str, actual_str, expected_len) != 0) { THROW_ERROR("received string is not as expected"); } @@ -120,7 +227,11 @@ static test_case_t test_cases[] = { TEST_CASE(test_fcntl_get_flags), TEST_CASE(test_fcntl_set_flags), TEST_CASE(test_create_with_flags), - TEST_CASE(test_read_write), + TEST_CASE(test_select_timeout), + TEST_CASE(test_poll_timeout), + TEST_CASE(test_select_no_timeout), + TEST_CASE(test_poll_no_timeout), + TEST_CASE(test_select_read_write), }; int main(int argc, const char *argv[]) { diff --git a/test/server/main.c b/test/server/main.c index c1fb58c1..d456d34e 100644 --- a/test/server/main.c +++ b/test/server/main.c @@ -306,7 +306,7 @@ int test_fcntl_setfl_and_getfl() { return ret; } -int test_poll_sockets() { +int test_poll_events_unchanged() { int socks[2], ret; socks[0] = socket(AF_INET, SOCK_STREAM, 0); socks[1] = socket(AF_INET, SOCK_STREAM, 0); @@ -329,13 +329,52 @@ int test_poll_sockets() { return 0; } +int test_poll() { + int child_pid = 0; + int client_fd = connect_with_child(8805, &child_pid); + if (client_fd < 0) { + THROW_ERROR("connect failed"); + } + + struct pollfd polls[] = { + { .fd = client_fd, .events = POLLIN } + }; + int ret = poll(polls, 1, -1); + if (ret <= 0) { + THROW_ERROR("poll error"); + } + + if (polls[0].revents & POLLIN) { + ssize_t count; + char buf[512]; + if ((count = read(client_fd, buf, sizeof buf)) != 0) { + if (strcmp(buf, DEFAULT_MSG) != 0) { + printf(buf); + THROW_ERROR("msg mismatched"); + } + } else { + THROW_ERROR("read error"); + } + } else { + THROW_ERROR("unexpected return events"); + } + + int status = 0; + if (wait4(child_pid, &status, 0, NULL) < 0) { + THROW_ERROR("failed to wait4 the child process"); + } + close(client_fd); + return 0; +} + static test_case_t test_cases[] = { TEST_CASE(test_read_write), TEST_CASE(test_send_recv), TEST_CASE(test_sendmsg_recvmsg), TEST_CASE(test_sendmsg_recvmsg_connectionless), TEST_CASE(test_fcntl_setfl_and_getfl), - TEST_CASE(test_poll_sockets), + TEST_CASE(test_poll), + TEST_CASE(test_poll_events_unchanged), }; int main(int argc, const char *argv[]) { diff --git a/test/unix_socket/main.c b/test/unix_socket/main.c index 8c830d92..2bd07ade 100644 --- a/test/unix_socket/main.c +++ b/test/unix_socket/main.c @@ -2,6 +2,7 @@ #include #include #include +#include #include #include #include @@ -181,10 +182,32 @@ int test_socketpair_inter_process() { return test_connected_sockets_inter_process(create_connceted_sockets_default); } +int test_poll() { + int socks[2]; + if (socketpair(AF_UNIX, SOCK_STREAM, 0, socks) < 0) { + THROW_ERROR("socketpair failed"); + } + + write(socks[0], "not today\n", 10); + + struct pollfd polls[] = { + { .fd = socks[1], .events = POLLIN }, + { .fd = socks[0], .events = POLLOUT }, + }; + + int ret = poll(polls, 2, 5000); + if (ret <= 0) { THROW_ERROR("poll error"); } + if ((polls[0].revents & POLLOUT) && (polls[1].revents && POLLIN) == 0) { + THROW_ERROR("wrong return events"); + } + return 0; +} + static test_case_t test_cases[] = { TEST_CASE(test_unix_socket_inter_process), TEST_CASE(test_socketpair_inter_process), TEST_CASE(test_multiple_socketpairs), + TEST_CASE(test_poll), }; int main(int argc, const char *argv[]) {