diff --git a/src/libos/src/fs/unix_socket.rs b/src/libos/src/fs/unix_socket.rs index dd984cbd..7f7b46a9 100644 --- a/src/libos/src/fs/unix_socket.rs +++ b/src/libos/src/fs/unix_socket.rs @@ -2,7 +2,7 @@ use super::*; use alloc::prelude::ToString; use std::collections::btree_map::BTreeMap; use std::fmt; -use std::sync::atomic::spin_loop_hint; +use std::sync::atomic::{spin_loop_hint, AtomicUsize, Ordering}; use std::sync::SgxMutex as Mutex; use util::ring_buf::{RingBuf, RingBufReader, RingBufWriter}; @@ -88,6 +88,9 @@ impl File for UnixSocketFile { } } +static SOCKETPAIR_NUM: AtomicUsize = AtomicUsize::new(0); +const SOCK_PATH_PREFIX: &str = "socketpair_"; + impl UnixSocketFile { pub fn new(socket_type: c_int, protocol: c_int) -> Result { let inner = UnixSocket::new(socket_type, protocol)?; @@ -123,6 +126,34 @@ impl UnixSocketFile { let mut inner = self.inner.lock().unwrap(); inner.poll() } + + pub fn socketpair(socket_type: i32, protocol: i32) -> Result<(Self, Self)> { + let listen_socket = Self::new(socket_type, protocol)?; + let bound_path = listen_socket.bind_until_success(); + listen_socket.listen()?; + + let client_socket = Self::new(socket_type, protocol)?; + client_socket.connect(&bound_path)?; + + let accepted_socket = listen_socket.accept()?; + Ok((client_socket, accepted_socket)) + } + + fn bind_until_success(&self) -> String { + let mut path = SOCK_PATH_PREFIX.to_string(); + let mut index = SOCKETPAIR_NUM.load(Ordering::SeqCst); + path.push_str(&index.to_string()); + while self.bind(&path).is_err() { + if index == std::usize::MAX { + SOCKETPAIR_NUM.store(0, Ordering::SeqCst); //flip SOCKETPAIR_NUM + } + index += 1; + path = SOCK_PATH_PREFIX.to_string(); + path.push_str(&index.to_string()); + } + SOCKETPAIR_NUM.fetch_max(index + 1, Ordering::SeqCst); + path + } } impl Debug for UnixSocketFile { diff --git a/src/libos/src/lib.rs b/src/libos/src/lib.rs index 89833180..8a968ad8 100644 --- a/src/libos/src/lib.rs +++ b/src/libos/src/lib.rs @@ -8,6 +8,7 @@ #![feature(range_contains)] #![feature(core_intrinsics)] #![feature(stmt_expr_attributes)] +#![feature(atomic_min_max)] #[macro_use] extern crate alloc; diff --git a/src/libos/src/syscall/mod.rs b/src/libos/src/syscall/mod.rs index ce64829c..bc1cdea9 100644 --- a/src/libos/src/syscall/mod.rs +++ b/src/libos/src/syscall/mod.rs @@ -295,6 +295,12 @@ pub extern "C" fn dispatch_syscall( arg4 as *mut libc::sockaddr, arg5 as *mut libc::socklen_t, ), + SYS_SOCKETPAIR => do_socketpair( + arg0 as c_int, + arg1 as c_int, + arg2 as c_int, + arg3 as *mut c_int, + ), _ => do_unknown(num, arg0, arg1, arg2, arg3, arg4, arg5), }; @@ -1026,7 +1032,7 @@ fn do_sched_setaffinity(pid: pid_t, cpusize: size_t, buf: *const c_uchar) -> Res fn do_socket(domain: c_int, socket_type: c_int, protocol: c_int) -> Result { info!( - "socket: domain: {}, socket_type: {}, protocol: {}", + "socket: domain: {}, socket_type: 0x{:x}, protocol: {}", domain, socket_type, protocol ); @@ -1320,6 +1326,46 @@ fn do_recvfrom( Ok(ret as isize) } +fn do_socketpair( + domain: c_int, + socket_type: c_int, + protocol: c_int, + sv: *mut c_int, +) -> Result { + info!( + "socketpair: domain: {}, type:0x{:x}, protocol: {}", + domain, socket_type, protocol + ); + let mut sock_pair = unsafe { + check_mut_array(sv, 2)?; + std::slice::from_raw_parts_mut(sv as *mut u32, 2) + }; + + if (domain == libc::AF_UNIX) { + let (client_socket, server_socket) = + UnixSocketFile::socketpair(socket_type as i32, protocol as i32)?; + let current_ref = process::get_current(); + let mut proc = current_ref.lock().unwrap(); + sock_pair[0] = proc + .get_files() + .lock() + .unwrap() + .put(Arc::new(Box::new(client_socket)), false); + sock_pair[1] = proc + .get_files() + .lock() + .unwrap() + .put(Arc::new(Box::new(server_socket)), false); + + info!("socketpair: ({}, {})", sock_pair[0], sock_pair[1]); + Ok(0) + } else if (domain == libc::AF_TIPC) { + return_errno!(EAFNOSUPPORT, "cluster domain sockets not supported") + } else { + return_errno!(EAFNOSUPPORT, "domain not supported") + } +} + fn do_select( nfds: c_int, readfds: *mut libc::fd_set, diff --git a/test/unix_socket/main.c b/test/unix_socket/main.c index 717dbec5..07e25ecb 100644 --- a/test/unix_socket/main.c +++ b/test/unix_socket/main.c @@ -9,106 +9,182 @@ #include #include -const char SOCK_PATH[] = "echo_socket"; +#include "test.h" -int create_server_socket() { - int fd = socket(AF_UNIX, SOCK_STREAM, 0); - if (fd == -1) { - printf("ERROR: failed to create a unix socket\n"); - return -1; - } +#define ECHO_MSG "echo msg for unix_socket test" +#define THROW_ERROR(msg) do { \ + printf("\t\tERROR: %s in func %s at line %d of file %s\n", \ + (msg), __func__, __LINE__, __FILE__); \ + return -1; \ +} while(0) - struct sockaddr_un local; - local.sun_family = AF_UNIX; - strcpy(local.sun_path, SOCK_PATH); - socklen_t len = strlen(local.sun_path) + sizeof(local.sun_family); +int create_connected_sockets(int *sockets, char *sock_path) { + int listen_fd = socket(AF_UNIX, SOCK_STREAM, 0); + if (listen_fd == -1) { + THROW_ERROR("failed to create a unix socket"); + } - if (bind(fd, (struct sockaddr *)&local, len) == -1) { - printf("ERROR: failed to bind\n"); - return -1; - } + struct sockaddr_un addr; + memset(&addr, 0, sizeof(struct sockaddr_un)); //Clear structure + addr.sun_family = AF_UNIX; + strcpy(addr.sun_path, sock_path); + socklen_t addr_len = strlen(addr.sun_path) + sizeof(addr.sun_family); + if (bind(listen_fd, (struct sockaddr *)&addr, addr_len) == -1) { + close(listen_fd); + THROW_ERROR("failed to bind"); + } - if (listen(fd, 5) == -1) { - printf("ERROR: failed to listen\n"); - return -1; - } - return fd; + if (listen(listen_fd, 5) == -1) { + close(listen_fd); + THROW_ERROR("failed to listen"); + } + + int client_fd = socket(AF_UNIX, SOCK_STREAM, 0); + if (client_fd == -1) { + close(listen_fd); + THROW_ERROR("failed to create a unix socket"); + } + + if (connect(client_fd, (struct sockaddr *)&addr, addr_len) == -1) { + close(listen_fd); + close(client_fd); + THROW_ERROR("failed to connect"); + } + + int accepted_fd = accept(listen_fd, (struct sockaddr *)&addr, &addr_len); + if (accepted_fd == -1) { + close(listen_fd); + close(client_fd); + THROW_ERROR("failed to accept socket"); + } + + sockets[0] = client_fd; + sockets[1] = accepted_fd; + close(listen_fd); + return 0; } -int create_client_socket() { - int fd = socket(AF_UNIX, SOCK_STREAM, 0); - if (fd == -1) { - printf("ERROR: failed to create a unix socket\n"); - return -1; - } - - struct sockaddr_un remote; - remote.sun_family = AF_UNIX; - strcpy(remote.sun_path, SOCK_PATH); - socklen_t len = strlen(remote.sun_path) + sizeof(remote.sun_family); - - if (connect(fd, (struct sockaddr *)&remote, len) == -1) { - printf("ERROR: failed to connect\n"); - return -1; - } - return fd; +int create_connceted_sockets_default(int *sockets) { + return create_connected_sockets(sockets, "unix_socket_default_path"); } -int main(int argc, const char* argv[]) { - int listen_fd = create_server_socket(); - if (listen_fd == -1) { - printf("ERROR: failed to create server socket\n"); - return -1; - } - - int socket_rd_fd = create_client_socket(); - if (socket_rd_fd == -1) { - printf("ERROR: failed to create client socket\n"); - return -1; - } - - struct sockaddr_un remote; - socklen_t len = sizeof(remote); - int socket_wr_fd = accept(listen_fd, (struct sockaddr *)&remote, &len); - if (socket_wr_fd == -1) { - printf("ERROR: failed to accept socket\n"); - return -1; - } - - // The following is same as 'pipe' - - posix_spawn_file_actions_t file_actions; - posix_spawn_file_actions_init(&file_actions); - posix_spawn_file_actions_adddup2(&file_actions, socket_wr_fd, STDOUT_FILENO); - posix_spawn_file_actions_addclose(&file_actions, socket_rd_fd); - - const char* msg = "Echo!\n"; +int verify_child_echo(int *connected_sockets) { const char* child_prog = "/bin/hello_world"; - const char* child_argv[3] = { child_prog, msg, NULL }; + const char* child_argv[3] = { child_prog, ECHO_MSG, NULL }; int child_pid; + posix_spawn_file_actions_t file_actions; + + posix_spawn_file_actions_init(&file_actions); + posix_spawn_file_actions_adddup2(&file_actions, connected_sockets[0], STDOUT_FILENO); + posix_spawn_file_actions_addclose(&file_actions, connected_sockets[1]); + if (posix_spawn(&child_pid, child_prog, &file_actions, NULL, (char*const*)child_argv, NULL) < 0) { - printf("ERROR: failed to spawn a child process\n"); - return -1; + THROW_ERROR("failed to spawn a child process"); } - close(socket_wr_fd); - const char* expected_str = msg; - size_t expected_len = strlen(expected_str); char actual_str[32] = {0}; ssize_t actual_len; + //TODO: implement blocking read do { - actual_len = read(socket_rd_fd, actual_str, sizeof(actual_str) - 1); + actual_len = read(connected_sockets[1], actual_str, 32); } while (actual_len == 0); - if (strncmp(expected_str, actual_str, expected_len) != 0) { - printf("ERROR: received string is not as expected\n"); - return -1; + if (strncmp(actual_str, ECHO_MSG, sizeof(ECHO_MSG) - 1) != 0) { + printf("data read is :%s\n", actual_str); + THROW_ERROR("received string is not as expected"); } int status = 0; if (wait4(child_pid, &status, 0, NULL) < 0) { - printf("ERROR: failed to wait4 the child process\n"); - return -1; + THROW_ERROR("failed to wait4 the child process"); + } + + return 0; +} + +int verify_connection(int src_sock, int dest_sock) { + char buf[1024]; + int i; + for (i = 0; i < 100; i++) { + if (write(src_sock, ECHO_MSG, sizeof(ECHO_MSG)) < 0) { + THROW_ERROR("writing server message"); + } + + if (read(dest_sock, buf, 1024) < 0) { + THROW_ERROR("reading server message"); + } + + if (strncmp(buf, ECHO_MSG, sizeof(ECHO_MSG)) != 0) { + THROW_ERROR("msg received mismatch"); + } } return 0; } + +//this value should not be too large as one pair consumes 2MB memory +#define PAIR_NUM 15 + +int test_multiple_socketpairs() { + int sockets[PAIR_NUM][2]; + int i; + int ret = 0; + + for(i = 0; i < PAIR_NUM; i++) { + if (socketpair(AF_UNIX, SOCK_STREAM, 0, sockets[i]) < 0) { + THROW_ERROR("opening stream socket pair"); + } + + if(verify_connection(sockets[i][0], sockets[i][1]) < 0) { + ret = -1; + goto cleanup; + } + + if(verify_connection(sockets[i][1], sockets[i][0]) < 0) { + ret = -1; + goto cleanup; + } + } + i--; +cleanup: + for(; i >= 0; i--){ + close(sockets[i][0]); + close(sockets[i][1]); + } + return ret; +} + +int socketpair_default(int *sockets) { + return socketpair(AF_UNIX, SOCK_STREAM, 0, sockets); +} + +typedef int(*create_connection_func_t)(int *); +int test_connected_sockets_inter_process(create_connection_func_t fn) { + int ret = 0; + int sockets[2]; + if (fn(sockets) < 0) + return -1; + + ret = verify_child_echo(sockets); + + close(sockets[0]); + close(sockets[1]); + return ret; +} + +int test_unix_socket_inter_process() { + return test_connected_sockets_inter_process(socketpair_default); +} + +int test_socketpair_inter_process() { + return test_connected_sockets_inter_process(create_connceted_sockets_default); +} + +static test_case_t test_cases[] = { + TEST_CASE(test_unix_socket_inter_process), + TEST_CASE(test_socketpair_inter_process), + TEST_CASE(test_multiple_socketpairs), +}; + +int main(int argc, const char* argv[]) { + return test_suite_run(test_cases, ARRAY_SIZE(test_cases)); +}