From 2ea7fc1ad674807ba0a03376ee029795c5ff27a2 Mon Sep 17 00:00:00 2001 From: He Sun Date: Wed, 30 Dec 2020 22:10:09 +0800 Subject: [PATCH] Add Unix socket support for poll A relay notifier that observes the underlying endpoint is added as the notifier for the socket. It broadcasts to its observers when either end of the channel has IoEvents. --- .../src/net/socket/unix/stream/endpoint.rs | 83 ++++++++++++++++++- src/libos/src/net/socket/unix/stream/file.rs | 20 +++-- .../src/net/socket/unix/stream/stream.rs | 21 ++++- test/unix_socket/main.c | 26 +++--- 4 files changed, 132 insertions(+), 18 deletions(-) diff --git a/src/libos/src/net/socket/unix/stream/endpoint.rs b/src/libos/src/net/socket/unix/stream/endpoint.rs index fa3ab370..e54fa470 100644 --- a/src/libos/src/net/socket/unix/stream/endpoint.rs +++ b/src/libos/src/net/socket/unix/stream/endpoint.rs @@ -1,6 +1,9 @@ use super::*; -use alloc::sync::{Arc, Weak}; +use events::{Event, EventFilter, Notifier, Observer}; use fs::channel::{Channel, Consumer, Producer}; +use fs::{IoEvents, IoNotifier}; +use std::any::Any; +use std::sync::{Arc, Weak}; pub type Endpoint = Arc; @@ -100,6 +103,36 @@ impl Inner { Ok(()) } + pub fn poll(&self) -> IoEvents { + let mut events = IoEvents::empty(); + let reader_events = self.reader.poll(); + let writer_events = self.writer.poll(); + + if reader_events.contains(IoEvents::HUP) || self.reader.is_self_shutdown() { + events |= IoEvents::RDHUP; + if writer_events.contains(IoEvents::ERR) || self.writer.is_self_shutdown() { + events |= IoEvents::HUP; + } + } + + events |= (reader_events & IoEvents::IN) | (writer_events & IoEvents::OUT); + events + } + + pub(self) fn register_relay_notifier(&self, observer: &Arc) { + self.reader.notifier().register( + Arc::downgrade(observer) as Weak>, + None, + None, + ); + + self.writer.notifier().register( + Arc::downgrade(observer) as Weak>, + None, + None, + ); + } + fn is_connected(&self) -> bool { self.peer.upgrade().is_some() } @@ -108,3 +141,51 @@ impl Inner { // TODO: Add SO_SNDBUF and SO_RCVBUF to set/getsockopt to dynamcally change the size. // This value is got from /proc/sys/net/core/rmem_max and wmem_max that are same on linux. pub const DEFAULT_BUF_SIZE: usize = 208 * 1024; + +/// An observer used to observe both reader and writer of the endpoint. It also contains a +/// notifier that relays the notification of the endpoint. +pub(super) struct RelayNotifier { + notifier: IoNotifier, + endpoint: SgxMutex>, +} + +impl RelayNotifier { + pub fn new() -> Self { + let notifier = IoNotifier::new(); + let endpoint = SgxMutex::new(None); + Self { notifier, endpoint } + } + + pub fn notifier(&self) -> &IoNotifier { + &self.notifier + } + + pub fn observe_endpoint(self: &Arc, endpoint: &Endpoint) { + endpoint.register_relay_notifier(self); + *self.endpoint.lock().unwrap() = Some(endpoint.clone()); + } +} + +impl Observer for RelayNotifier { + fn on_event(&self, event: &IoEvents, _metadata: &Option>) { + let endpoint = self.endpoint.lock().unwrap(); + // Only endpoint can broadcast events + + let mut event = event.clone(); + // The event of the channel should not be broadcasted directly to socket. + // The event transformation should be consistant with poll. + if event.contains(IoEvents::HUP) { + event -= IoEvents::HUP; + event |= IoEvents::RDHUP; + } + + if event.contains(IoEvents::ERR) { + event -= IoEvents::ERR; + event |= IoEvents::HUP; + } + + // A notifier can only have events after observe_endpoint + self.notifier() + .broadcast(&(endpoint.as_ref().unwrap().poll() & event)); + } +} diff --git a/src/libos/src/net/socket/unix/stream/file.rs b/src/libos/src/net/socket/unix/stream/file.rs index a39ed671..20722554 100644 --- a/src/libos/src/net/socket/unix/stream/file.rs +++ b/src/libos/src/net/socket/unix/stream/file.rs @@ -1,6 +1,6 @@ use super::stream::Status; use super::*; -use fs::{AccessMode, File, FileRef, IoctlCmd, StatusFlags}; +use fs::{AccessMode, File, FileRef, IoEvents, IoNotifier, IoctlCmd, StatusFlags}; use std::any::Any; impl File for Stream { @@ -90,10 +90,20 @@ impl File for Stream { Ok(()) } - fn poll(&self) -> Result { - warn!("poll is not supported for unix_socket"); - let events = PollEventFlags::empty(); - Ok(events) + fn poll_new(&self) -> IoEvents { + match &*self.inner() { + // linux return value + Status::Idle(info) => IoEvents::OUT | IoEvents::HUP, + Status::Connected(endpoint) => endpoint.poll(), + Status::Listening(_) => { + warn!("poll is not fully implemented for the listener socket"); + IoEvents::empty() + } + } + } + + fn notifier(&self) -> Option<&IoNotifier> { + Some(&self.notifier.notifier()) } fn as_any(&self) -> &dyn Any { diff --git a/src/libos/src/net/socket/unix/stream/stream.rs b/src/libos/src/net/socket/unix/stream/stream.rs index 66423389..b5923cc9 100644 --- a/src/libos/src/net/socket/unix/stream/stream.rs +++ b/src/libos/src/net/socket/unix/stream/stream.rs @@ -1,10 +1,12 @@ use super::address_space::ADDRESS_SPACE; -use super::endpoint::{end_pair, Endpoint}; +use super::endpoint::{end_pair, Endpoint, RelayNotifier}; use super::*; -use alloc::sync::Arc; +use events::{Event, EventFilter, Notifier, Observer}; use fs::channel::Channel; +use fs::IoEvents; use std::fmt; use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; /// SOCK_STREAM Unix socket. It has three statuses: unconnected, listening and connected. When a /// socket is created, it is in unconnected status. It will transfer to listening after listen is @@ -13,6 +15,9 @@ use std::sync::atomic::{AtomicBool, Ordering}; /// will not transfer to other statuses. pub struct Stream { inner: SgxMutex, + // Use the internal notifier of RelayNotifier as the notifier of stream socket. It relays the + // events of the endpoint, too. + pub(super) notifier: Arc, } impl Stream { @@ -21,19 +26,26 @@ impl Stream { inner: SgxMutex::new(Status::Idle(Info::new( flags.contains(FileFlags::SOCK_NONBLOCK), ))), + notifier: Arc::new(RelayNotifier::new()), } } pub fn socketpair(flags: FileFlags) -> Result<(Self, Self)> { let nonblocking = flags.contains(FileFlags::SOCK_NONBLOCK); let (end_a, end_b) = end_pair(nonblocking)?; + let notifier_a = Arc::new(RelayNotifier::new()); + let notifier_b = Arc::new(RelayNotifier::new()); + notifier_a.observe_endpoint(&end_a); + notifier_b.observe_endpoint(&end_b); let socket_a = Self { inner: SgxMutex::new(Status::Connected(end_a)), + notifier: notifier_a, }; let socket_b = Self { inner: SgxMutex::new(Status::Connected(end_b)), + notifier: notifier_b, }; Ok((socket_a, socket_b)) @@ -135,6 +147,8 @@ impl Stream { _ => e, })?; + self.notifier.observe_endpoint(&end_self); + *inner = Status::Connected(end_self); Ok(()) } @@ -149,6 +163,8 @@ impl Stream { Status::Listening(addr) => { let endpoint = ADDRESS_SPACE.pop_incoming(&addr)?; endpoint.set_nonblocking(flags.contains(FileFlags::SOCK_NONBLOCK)); + let notifier = Arc::new(RelayNotifier::new()); + notifier.observe_endpoint(&endpoint); let peer_addr = endpoint.peer_addr(); @@ -157,6 +173,7 @@ impl Stream { Ok(( Self { inner: SgxMutex::new(Status::Connected(endpoint)), + notifier: notifier, }, peer_addr, )) diff --git a/test/unix_socket/main.c b/test/unix_socket/main.c index 82692fa8..ebd6bb83 100644 --- a/test/unix_socket/main.c +++ b/test/unix_socket/main.c @@ -79,12 +79,19 @@ int verify_child_echo(int *connected_sockets) { THROW_ERROR("failed to spawn a child process"); } + struct pollfd polls[] = { + { .fd = connected_sockets[1], .events = POLLIN }, + }; + + // Test for blocking poll, poll will be only interrupted by sigchld + // if socket does not support waking up a sleeping poller + int ret = poll(polls, 1, -1); + if (ret < 0) { + THROW_ERROR("failed to poll"); + } + char actual_str[32] = {0}; - ssize_t actual_len; - //TODO: implement blocking read - do { - actual_len = read(connected_sockets[1], actual_str, 32); - } while (actual_len == 0); + read(connected_sockets[1], actual_str, 32); if (strncmp(actual_str, ECHO_MSG, sizeof(ECHO_MSG) - 1) != 0) { printf("data read is :%s\n", actual_str); THROW_ERROR("received string is not as expected"); @@ -191,13 +198,14 @@ int test_poll() { write(socks[0], "not today\n", 10); struct pollfd polls[] = { - { .fd = socks[1], .events = POLLIN }, { .fd = socks[0], .events = POLLOUT }, + { .fd = socks[1], .events = POLLIN }, }; int ret = poll(polls, 2, 5000); if (ret <= 0) { THROW_ERROR("poll error"); } - if ((polls[0].revents & POLLOUT) && (polls[1].revents && POLLIN) == 0) { + if (((polls[0].revents & POLLOUT) && (polls[1].revents & POLLIN)) == 0) { + printf("%d %d\n", polls[0].revents, polls[1].revents); THROW_ERROR("wrong return events"); } return 0; @@ -241,9 +249,7 @@ static test_case_t test_cases[] = { TEST_CASE(test_unix_socket_inter_process), TEST_CASE(test_socketpair_inter_process), TEST_CASE(test_multiple_socketpairs), - // TODO: recover the test after the unix sockets are rewritten by using - // the new event subsystem - //TEST_CASE(test_poll), + TEST_CASE(test_poll), TEST_CASE(test_getname), };