From 0f789b49bcdc7d7faea1d5d4cf071338acb7ef1a Mon Sep 17 00:00:00 2001 From: "Hui, Chunyang" Date: Mon, 14 Nov 2022 09:27:23 +0000 Subject: [PATCH] Fix exit_group not interrupt wait4 --- src/libos/src/process/do_exit.rs | 9 +++ src/libos/src/process/do_wait4.rs | 13 ++-- src/libos/src/process/wait.rs | 17 +++++- test/client/main.c | 34 +++++++++++ test/server/Makefile | 2 +- test/server/main.c | 98 +++++++++++++++++++++++++++++++ 6 files changed, 165 insertions(+), 8 deletions(-) diff --git a/src/libos/src/process/do_exit.rs b/src/libos/src/process/do_exit.rs index 4dee8c77..4db94c1a 100644 --- a/src/libos/src/process/do_exit.rs +++ b/src/libos/src/process/do_exit.rs @@ -19,6 +19,15 @@ pub fn do_exit_group(status: i32, curr_user_ctxt: &mut CpuContext) -> Result Result<(p // without risking missing events from the process's children. drop(process_inner); // Wait until a child has interesting events - let zombie_pid = waiter.sleep_until_woken_with_result(); - - let mut process_inner = process.inner(); - let exit_status = free_zombie_child(process_inner, zombie_pid); - Ok((zombie_pid, exit_status)) + if let Some(zombie_pid) = waiter.sleep_until_woken_with_result() { + let mut process_inner = process.inner(); + let exit_status = free_zombie_child(process_inner, zombie_pid); + Ok((zombie_pid, exit_status)) + } else { + // The wait is interrupted + return_errno!(EINTR, "wait is interrupted and not get any children"); + } } fn free_zombie_child(mut parent_inner: SgxMutexGuard, zombie_pid: pid_t) -> i32 { diff --git a/src/libos/src/process/wait.rs b/src/libos/src/process/wait.rs index b6d42543..2c206e03 100644 --- a/src/libos/src/process/wait.rs +++ b/src/libos/src/process/wait.rs @@ -51,14 +51,14 @@ where self.inner.lock().unwrap().data } - pub fn sleep_until_woken_with_result(self) -> R { + pub fn sleep_until_woken_with_result(self) -> Option { while !self.inner.lock().unwrap().is_woken { unsafe { wait_event(self.thread); } } - self.inner.lock().unwrap().result.unwrap() + self.inner.lock().unwrap().result } } @@ -114,4 +114,17 @@ where set_event(del_waiter.thread); 1 } + + pub fn del_and_wake_all_waiters(&mut self) -> usize { + let mut waiters = &mut self.waiters; + let ret = waiters.len(); + waiters.drain(..).for_each(|waiter| { + let mut waiter_inner = waiter.inner.lock().unwrap(); + waiter_inner.is_woken = true; + waiter_inner.result = None; + set_event(waiter.thread); + }); + + ret + } } diff --git a/test/client/main.c b/test/client/main.c index 2eccbda5..dbf34cd9 100644 --- a/test/client/main.c +++ b/test/client/main.c @@ -9,6 +9,7 @@ #include #include #include +#include #include "test.h" @@ -159,6 +160,28 @@ int client_connectionless_sendmsg(char *buf) { return ret; } +int blocking_recvfrom(int server_fd, char *buf, int buf_size) { + int flags = fcntl(server_fd, F_GETFL, 0); + if (flags == -1) { + THROW_ERROR("fnctl failed"); + } + flags = flags & ~O_NONBLOCK; + fcntl(server_fd, F_SETFL, flags); + if (flags == -1) { + THROW_ERROR("fnctl failed"); + } + + // wait for server to exit and the remote end is closed + sleep(1); + printf("client blocking recvfrom\n"); + int ret = recvfrom(server_fd, buf, buf_size, 0, NULL, 0); + if (ret >= 0 || errno != ECONNRESET) { + THROW_ERROR("recvfrom failed"); + } + + return 0; +} + int main(int argc, const char *argv[]) { if (argc != 3) { THROW_ERROR("usage: ./client \n"); @@ -190,6 +213,17 @@ int main(int argc, const char *argv[]) { case 8804: ret = client_connectionless_sendmsg(DEFAULT_MSG); break; + case 8888: + neogotiate_msg(server_fd, buf, buf_size); + ret = client_sendmsg(server_fd, buf); + if (ret < 0) { + THROW_ERROR("client sendmsg failed"); + } + ret = blocking_recvfrom(server_fd, buf, buf_size); + if (ret < 0) { + THROW_ERROR("client blocking recvfrom failed"); + } + break; default: ret = client_send(server_fd, DEFAULT_MSG); } diff --git a/test/server/Makefile b/test/server/Makefile index 9e1b6dec..5c1ee8c1 100644 --- a/test/server/Makefile +++ b/test/server/Makefile @@ -1,5 +1,5 @@ include ../test_common.mk EXTRA_C_FLAGS := -EXTRA_LINK_FLAGS := +EXTRA_LINK_FLAGS := -lpthread BIN_ARGS := diff --git a/test/server/main.c b/test/server/main.c index 2b61b60a..d578753f 100644 --- a/test/server/main.c +++ b/test/server/main.c @@ -1,3 +1,4 @@ +#define _GNU_SOURCE #include #include #include @@ -11,6 +12,7 @@ #include #include #include +#include #include "test.h" @@ -209,6 +211,14 @@ int wait_for_child_exit(int child_pid) { return 0; } +static void *thread_wait_func(void *_arg) { + pid_t *client_pid = _arg; + + waitpid(*client_pid, NULL, 0); + + return NULL; +} + int test_read_write() { int ret = 0; int child_pid = 0; @@ -403,6 +413,93 @@ int test_poll() { return 0; } +// This is a testcase mocking pyspark exit procedure. Client process is receiving and blocking. +// One of server process' child thread waits for the client to exit and the main thread calls exit_group. +static int test_exit_group() { + int port = 8888; + int pipes[2]; + int ret = 0; + int listen_fd = socket(AF_INET, SOCK_STREAM, 0); + if (listen_fd < 0) { + THROW_ERROR("create socket error"); + } + + ret = pipe2(pipes, 0); + if (ret < 0) { + THROW_ERROR("error happens"); + } + + printf("pipe fd = %d, %d\n", pipes[0], pipes[1]); + + int child_pid = vfork(); + if (child_pid == 0) { + ret = close(pipes[1]); + if (ret < 0) { + THROW_ERROR("error happens"); + } + ret = dup2(pipes[0], 0); + if (ret < 0) { + THROW_ERROR("error happens"); + } + + ret = close(pipes[0]); + if (ret < 0) { + THROW_ERROR("error happens"); + } + + char port_string[8]; + sprintf(port_string, "%d", port); + char *client_argv[] = {"client", "127.0.0.1", port_string, NULL}; + printf("exec child\n"); + execve("/bin/client", client_argv, NULL); + } + + printf("return to parent\n"); + close(pipes[0]); + + int reuse = 1; + if (setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)) < 0) { + THROW_ERROR("setsockopt port to reuse failed"); + } + + struct sockaddr_in servaddr; + memset(&servaddr, 0, sizeof(servaddr)); + servaddr.sin_family = AF_INET; + servaddr.sin_addr.s_addr = htonl(INADDR_ANY); + servaddr.sin_port = htons(port); + ret = bind(listen_fd, (struct sockaddr *) &servaddr, sizeof(servaddr)); + if (ret < 0) { + close(listen_fd); + THROW_ERROR("bind socket failed"); + } + + ret = listen(listen_fd, 5); + if (ret < 0) { + close(listen_fd); + THROW_ERROR("listen socket error"); + } + + int connected_fd = accept(listen_fd, (struct sockaddr *) NULL, NULL); // 4 + if (connected_fd < 0) { + close(listen_fd); + THROW_ERROR("accept socket error"); + } + + if (neogotiate_msg(connected_fd) < 0) { + THROW_ERROR("neogotiate failed"); + } + + pthread_t tid; + ret = pthread_create(&tid, NULL, thread_wait_func, &child_pid); + if (ret != 0) { + THROW_ERROR("create child error"); + } + + // Wait a while here for client to call recvfrom and blocking + sleep(2); + return 0; +} + static test_case_t test_cases[] = { TEST_CASE(test_read_write), TEST_CASE(test_send_recv), @@ -414,6 +511,7 @@ static test_case_t test_cases[] = { TEST_CASE(test_fcntl_setfl_and_getfl), TEST_CASE(test_poll), TEST_CASE(test_poll_events_unchanged), + TEST_CASE(test_exit_group), }; int main(int argc, const char *argv[]) {