From 71c4937b45cf6080acafa4de6b6650267e242ec9 Mon Sep 17 00:00:00 2001 From: "Hui, Chunyang" Date: Tue, 9 Aug 2022 08:54:02 +0000 Subject: [PATCH] Fix listening socket epoll_wait not waken by connect --- .../net/socket/unix/stream/address_space.rs | 14 +- src/libos/src/net/socket/unix/stream/file.rs | 10 +- .../src/net/socket/unix/stream/stream.rs | 30 ++++- test/unix_socket/Makefile | 2 +- test/unix_socket/main.c | 121 ++++++++++++++++-- 5 files changed, 157 insertions(+), 20 deletions(-) diff --git a/src/libos/src/net/socket/unix/stream/address_space.rs b/src/libos/src/net/socket/unix/stream/address_space.rs index 7cbfc14e..551a8ca6 100644 --- a/src/libos/src/net/socket/unix/stream/address_space.rs +++ b/src/libos/src/net/socket/unix/stream/address_space.rs @@ -1,4 +1,5 @@ use super::endpoint::Endpoint; +use super::endpoint::RelayNotifier; use super::stream::Listener; use super::*; use std::collections::btree_map::BTreeMap; @@ -49,13 +50,22 @@ impl AddressSpace { } } - pub fn add_listener(&self, addr: &Addr, capacity: usize, nonblocking: bool) -> Result<()> { + pub(super) fn add_listener( + &self, + addr: &Addr, + capacity: usize, + nonblocking: bool, + notifier: Arc, + ) -> Result<()> { let key = Self::get_key(addr).ok_or_else(|| errno!(EINVAL, "the socket is not bound"))?; let mut space = self.get_space(addr); if let Some(option) = space.get(&key) { if option.is_none() { - space.insert(key, Some(Arc::new(Listener::new(capacity, nonblocking)?))); + space.insert( + key, + Some(Arc::new(Listener::new(capacity, nonblocking, notifier)?)), + ); Ok(()) } else { return_errno!(EINVAL, "the socket is already listened"); diff --git a/src/libos/src/net/socket/unix/stream/file.rs b/src/libos/src/net/socket/unix/stream/file.rs index 7e7c644e..da3cbcc7 100644 --- a/src/libos/src/net/socket/unix/stream/file.rs +++ b/src/libos/src/net/socket/unix/stream/file.rs @@ -1,3 +1,4 @@ +use super::address_space::ADDRESS_SPACE; use super::stream::Status; use super::*; use fs::{AccessMode, File, FileRef, IoEvents, IoNotifier, IoctlCmd, StatusFlags}; @@ -100,9 +101,12 @@ impl File for Stream { // linux return value Status::Idle(info) => IoEvents::OUT | IoEvents::HUP, Status::Connected(endpoint) => endpoint.poll(), - Status::Listening(_) => { - warn!("poll is not fully implemented for the listener socket"); - IoEvents::empty() + Status::Listening(addr) => { + if let Some(listener) = ADDRESS_SPACE.get_listener_ref(addr) { + listener.poll_new() + } else { + IoEvents::empty() + } } } } diff --git a/src/libos/src/net/socket/unix/stream/stream.rs b/src/libos/src/net/socket/unix/stream/stream.rs index 45c1684e..d7e1d16f 100644 --- a/src/libos/src/net/socket/unix/stream/stream.rs +++ b/src/libos/src/net/socket/unix/stream/stream.rs @@ -121,7 +121,12 @@ impl Stream { Status::Idle(info) => { if let Some(addr) = info.addr() { warn!("addr = {:?}", addr); - ADDRESS_SPACE.add_listener(addr, capacity, info.nonblocking())?; + ADDRESS_SPACE.add_listener( + addr, + capacity, + info.nonblocking(), + self.notifier.clone(), + )?; *inner = Status::Listening(addr.clone()); } else { return_errno!(EINVAL, "the socket is not bound"); @@ -165,6 +170,11 @@ impl Stream { self.notifier.observe_endpoint(&end_self); + // Notify listener for this event + if let Some(listener) = ADDRESS_SPACE.get_listener_ref(addr) { + listener.notifier.notifier().broadcast(&IoEvents::IN); + } + *inner = Status::Connected(end_self); Ok(()) } @@ -322,16 +332,22 @@ impl Info { /// ECONNREFUSED rather than block when the channel is full. pub struct Listener { channel: RwLock>, + notifier: Arc, } impl Listener { - pub fn new(capacity: usize, nonblocking: bool) -> Result { + pub(super) fn new( + capacity: usize, + nonblocking: bool, + notifier: Arc, + ) -> Result { let channel = Channel::new(capacity)?; channel.producer().set_nonblocking(true); channel.consumer().set_nonblocking(nonblocking); Ok(Self { channel: RwLock::new(channel), + notifier, }) } @@ -391,4 +407,14 @@ impl Listener { let channel = self.channel.read().unwrap(); channel.shutdown(); } + + pub fn poll_new(&self) -> IoEvents { + let mut events = IoEvents::empty(); + let channel = self.channel.read().unwrap(); + let item_num = channel.items_to_consume(); + if item_num > 0 { + events |= IoEvents::IN; + } + events + } } diff --git a/test/unix_socket/Makefile b/test/unix_socket/Makefile index 8c5a2fb4..3ec3f733 100644 --- a/test/unix_socket/Makefile +++ b/test/unix_socket/Makefile @@ -1,5 +1,5 @@ include ../test_common.mk EXTRA_C_FLAGS := -Wno-incompatible-pointer-types-discards-qualifiers -EXTRA_LINK_FLAGS := +EXTRA_LINK_FLAGS := -lpthread BIN_ARGS := diff --git a/test/unix_socket/main.c b/test/unix_socket/main.c index 6bea81c1..156127c5 100644 --- a/test/unix_socket/main.c +++ b/test/unix_socket/main.c @@ -9,6 +9,8 @@ #include #include #include +#include +#include #include "test.h" @@ -21,7 +23,7 @@ int create_connected_sockets(int *sockets, char *sock_path) { } struct sockaddr_un addr; - memset(&addr, 0, sizeof(struct sockaddr_un)); //Clear structure + memset(&addr, 0, sizeof(struct sockaddr_un)); // Clear structure addr.sun_family = AF_UNIX; strcpy(addr.sun_path, sock_path); socklen_t addr_len = strlen(addr.sun_path) + sizeof(addr.sun_family) + 1; @@ -73,7 +75,7 @@ int create_connected_sockets_then_rename(int *sockets) { } struct sockaddr_un addr; - memset(&addr, 0, sizeof(struct sockaddr_un)); //Clear structure + memset(&addr, 0, sizeof(struct sockaddr_un)); // Clear structure addr.sun_family = AF_UNIX; strcpy(addr.sun_path, socket_original_path); @@ -107,7 +109,7 @@ int create_connected_sockets_then_rename(int *sockets) { } struct sockaddr_un addr_client; - memset(&addr_client, 0, sizeof(struct sockaddr_un)); //Clear structure + memset(&addr_client, 0, sizeof(struct sockaddr_un)); // Clear structure addr_client.sun_family = AF_UNIX; strcpy(addr_client.sun_path, "/proc/self/root"); strcat(addr_client.sun_path, socket_ready_path); @@ -135,7 +137,7 @@ int create_connected_sockets_then_rename(int *sockets) { int verify_child_echo(int *connected_sockets) { const char *child_prog = "/bin/hello_world"; - const char *child_argv[3] = { child_prog, ECHO_MSG, NULL }; + const char *child_argv[3] = {child_prog, ECHO_MSG, NULL}; int child_pid; posix_spawn_file_actions_t file_actions; @@ -149,7 +151,7 @@ int verify_child_echo(int *connected_sockets) { } struct pollfd polls[] = { - { .fd = connected_sockets[1], .events = POLLIN }, + {.fd = connected_sockets[1], .events = POLLIN}, }; // Test for blocking poll, poll will be only interrupted by sigchld @@ -199,7 +201,7 @@ int verify_connection(int src_sock, int dest_sock) { return 0; } -//this value should not be too large as one pair consumes 2MB memory +// this value should not be too large as one pair consumes 2MB memory #define PAIR_NUM 15 int test_multiple_socketpairs() { @@ -235,7 +237,7 @@ int socketpair_default(int *sockets) { return socketpair(AF_UNIX, SOCK_STREAM, 0, sockets); } -typedef int(*create_connection_func_t)(int *); +typedef int (*create_connection_func_t)(int *); int test_connected_sockets_inter_process(create_connection_func_t fn) { int ret = 0; int sockets[2]; @@ -274,12 +276,14 @@ int test_poll() { } struct pollfd polls[] = { - { .fd = socks[0], .events = POLLOUT }, - { .fd = socks[1], .events = POLLIN }, + {.fd = socks[0], .events = POLLOUT}, + {.fd = socks[1], .events = POLLIN}, }; int ret = poll(polls, 2, 5000); - if (ret <= 0) { THROW_ERROR("poll error"); } + if (ret <= 0) { + THROW_ERROR("poll error"); + } if (((polls[0].revents & POLLOUT) && (polls[1].revents & POLLIN)) == 0) { printf("%d %d\n", polls[0].revents, polls[1].revents); THROW_ERROR("wrong return events"); @@ -295,7 +299,7 @@ int test_getname() { } struct sockaddr_un addr = {0}; - memset(&addr, 0, sizeof(struct sockaddr_un)); //Clear structure + memset(&addr, 0, sizeof(struct sockaddr_un)); // Clear structure addr.sun_family = AF_UNIX; strcpy(addr.sun_path, name); socklen_t addr_len = strlen(addr.sun_path) + sizeof(addr.sun_family) + 1; @@ -330,7 +334,7 @@ int test_ioctl_fionread() { } const char *child_prog = "/bin/hello_world"; - const char *child_argv[3] = { child_prog, ECHO_MSG, NULL }; + const char *child_argv[3] = {child_prog, ECHO_MSG, NULL}; int child_pid; posix_spawn_file_actions_t file_actions; @@ -369,6 +373,98 @@ int test_ioctl_fionread() { return 0; } +void *client_routine(void *arg) { + // Sleep a while before connect + sleep(3); + printf("sleep is done\n"); + + char *sock_path = "/tmp/test.sock"; + struct sockaddr_un addr; + memset(&addr, 0, sizeof(struct sockaddr_un)); // Clear structure + addr.sun_family = AF_UNIX; + strcpy(addr.sun_path, sock_path); + socklen_t addr_len = strlen(addr.sun_path) + sizeof(addr.sun_family) + 1; + + int client_fd = socket(AF_UNIX, SOCK_STREAM, 0); + if (client_fd == -1) { + printf("failed to create a unix socket\n"); + return NULL; + } + + if (connect(client_fd, (struct sockaddr *)&addr, addr_len) == -1) { + close(client_fd); + printf("failed to connect\n"); + return NULL; + } + + printf("connect success\n"); + return NULL; +} + +int test_epoll_wait() { + char *sock_path = "/tmp/test.sock"; + int ret; + struct epoll_event event; + uint32_t interest_events = EPOLLIN | EPOLLOUT; + struct epoll_event polled_events; + pthread_t client_tid; + struct sockaddr_un addr; + + int listen_fd = socket(AF_UNIX, SOCK_STREAM, 0); + if (listen_fd == -1) { + THROW_ERROR("failed to create a unix socket"); + } + + memset(&addr, 0, sizeof(struct sockaddr_un)); // Clear structure + addr.sun_family = AF_UNIX; + strcpy(addr.sun_path, sock_path); + socklen_t addr_len = strlen(addr.sun_path) + sizeof(addr.sun_family) + 1; + unlink(addr.sun_path); + if (bind(listen_fd, (struct sockaddr *)&addr, addr_len) == -1) { + close(listen_fd); + THROW_ERROR("failed to bind"); + } + + if (listen(listen_fd, 5) != 0) { + THROW_ERROR("server_routine, error in listen"); + } + + int ep_fd = epoll_create1(0); + if (ep_fd < 0) { + THROW_ERROR("failed to create an epoll"); + } + + event.events = interest_events; + event.data.u32 = listen_fd; + ret = epoll_ctl(ep_fd, EPOLL_CTL_ADD, listen_fd, &event); + if (ret < 0) { + THROW_ERROR("failed to do epoll ctl"); + } + + if (pthread_create(&client_tid, NULL, client_routine, NULL)) { + THROW_ERROR("Failure creating client thread"); + } + + // wait infinitely + ret = epoll_wait(ep_fd, &polled_events, 1, -1); + if (ret != 1) { + THROW_ERROR("failed to do epoll wait"); + } + + if (polled_events.data.u32 != listen_fd) { + THROW_ERROR("epoll wait returned wrong fd"); + } + + int newsock = accept(listen_fd, (struct sockaddr *)&addr, &addr_len); + if (newsock == -1) { + THROW_ERROR("server_routine, error in accept"); + } + + printf("accept done, new socket = %d\n", newsock); + pthread_join(client_tid, NULL); + return 0; +} + static test_case_t test_cases[] = { TEST_CASE(test_unix_socket_inter_process), TEST_CASE(test_socketpair_inter_process), @@ -377,6 +473,7 @@ static test_case_t test_cases[] = { TEST_CASE(test_getname), TEST_CASE(test_ioctl_fionread), TEST_CASE(test_unix_socket_rename), + TEST_CASE(test_epoll_wait), }; int main(int argc, const char *argv[]) {