Add socketpair syscall
1. Support creating socket pairs of the UNIX domain 2. Add test for socketpair in test/unix_socket 3. Refactor unix_socket test
This commit is contained in:
		
							parent
							
								
									9c4391b32d
								
							
						
					
					
						commit
						dc14f27a29
					
				| @ -2,7 +2,7 @@ use super::*; | ||||
| use alloc::prelude::ToString; | ||||
| use std::collections::btree_map::BTreeMap; | ||||
| use std::fmt; | ||||
| use std::sync::atomic::spin_loop_hint; | ||||
| use std::sync::atomic::{spin_loop_hint, AtomicUsize, Ordering}; | ||||
| use std::sync::SgxMutex as Mutex; | ||||
| use util::ring_buf::{RingBuf, RingBufReader, RingBufWriter}; | ||||
| 
 | ||||
| @ -88,6 +88,9 @@ impl File for UnixSocketFile { | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| static SOCKETPAIR_NUM: AtomicUsize = AtomicUsize::new(0); | ||||
| const SOCK_PATH_PREFIX: &str = "socketpair_"; | ||||
| 
 | ||||
| impl UnixSocketFile { | ||||
|     pub fn new(socket_type: c_int, protocol: c_int) -> Result<Self> { | ||||
|         let inner = UnixSocket::new(socket_type, protocol)?; | ||||
| @ -123,6 +126,34 @@ impl UnixSocketFile { | ||||
|         let mut inner = self.inner.lock().unwrap(); | ||||
|         inner.poll() | ||||
|     } | ||||
| 
 | ||||
|     pub fn socketpair(socket_type: i32, protocol: i32) -> Result<(Self, Self)> { | ||||
|         let listen_socket = Self::new(socket_type, protocol)?; | ||||
|         let bound_path = listen_socket.bind_until_success(); | ||||
|         listen_socket.listen()?; | ||||
| 
 | ||||
|         let client_socket = Self::new(socket_type, protocol)?; | ||||
|         client_socket.connect(&bound_path)?; | ||||
| 
 | ||||
|         let accepted_socket = listen_socket.accept()?; | ||||
|         Ok((client_socket, accepted_socket)) | ||||
|     } | ||||
| 
 | ||||
|     fn bind_until_success(&self) -> String { | ||||
|         let mut path = SOCK_PATH_PREFIX.to_string(); | ||||
|         let mut index = SOCKETPAIR_NUM.load(Ordering::SeqCst); | ||||
|         path.push_str(&index.to_string()); | ||||
|         while self.bind(&path).is_err() { | ||||
|             if index == std::usize::MAX { | ||||
|                 SOCKETPAIR_NUM.store(0, Ordering::SeqCst); //flip SOCKETPAIR_NUM
 | ||||
|             } | ||||
|             index += 1; | ||||
|             path = SOCK_PATH_PREFIX.to_string(); | ||||
|             path.push_str(&index.to_string()); | ||||
|         } | ||||
|         SOCKETPAIR_NUM.fetch_max(index + 1, Ordering::SeqCst); | ||||
|         path | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl Debug for UnixSocketFile { | ||||
|  | ||||
| @ -8,6 +8,7 @@ | ||||
| #![feature(range_contains)] | ||||
| #![feature(core_intrinsics)] | ||||
| #![feature(stmt_expr_attributes)] | ||||
| #![feature(atomic_min_max)] | ||||
| 
 | ||||
| #[macro_use] | ||||
| extern crate alloc; | ||||
|  | ||||
| @ -295,6 +295,12 @@ pub extern "C" fn dispatch_syscall( | ||||
|             arg4 as *mut libc::sockaddr, | ||||
|             arg5 as *mut libc::socklen_t, | ||||
|         ), | ||||
|         SYS_SOCKETPAIR => do_socketpair( | ||||
|             arg0 as c_int, | ||||
|             arg1 as c_int, | ||||
|             arg2 as c_int, | ||||
|             arg3 as *mut c_int, | ||||
|         ), | ||||
| 
 | ||||
|         _ => do_unknown(num, arg0, arg1, arg2, arg3, arg4, arg5), | ||||
|     }; | ||||
| @ -1026,7 +1032,7 @@ fn do_sched_setaffinity(pid: pid_t, cpusize: size_t, buf: *const c_uchar) -> Res | ||||
| 
 | ||||
| fn do_socket(domain: c_int, socket_type: c_int, protocol: c_int) -> Result<isize> { | ||||
|     info!( | ||||
|         "socket: domain: {}, socket_type: {}, protocol: {}", | ||||
|         "socket: domain: {}, socket_type: 0x{:x}, protocol: {}", | ||||
|         domain, socket_type, protocol | ||||
|     ); | ||||
| 
 | ||||
| @ -1320,6 +1326,46 @@ fn do_recvfrom( | ||||
|     Ok(ret as isize) | ||||
| } | ||||
| 
 | ||||
| fn do_socketpair( | ||||
|     domain: c_int, | ||||
|     socket_type: c_int, | ||||
|     protocol: c_int, | ||||
|     sv: *mut c_int, | ||||
| ) -> Result<isize> { | ||||
|     info!( | ||||
|         "socketpair: domain: {}, type:0x{:x}, protocol: {}", | ||||
|         domain, socket_type, protocol | ||||
|     ); | ||||
|     let mut sock_pair = unsafe { | ||||
|         check_mut_array(sv, 2)?; | ||||
|         std::slice::from_raw_parts_mut(sv as *mut u32, 2) | ||||
|     }; | ||||
| 
 | ||||
|     if (domain == libc::AF_UNIX) { | ||||
|         let (client_socket, server_socket) = | ||||
|             UnixSocketFile::socketpair(socket_type as i32, protocol as i32)?; | ||||
|         let current_ref = process::get_current(); | ||||
|         let mut proc = current_ref.lock().unwrap(); | ||||
|         sock_pair[0] = proc | ||||
|             .get_files() | ||||
|             .lock() | ||||
|             .unwrap() | ||||
|             .put(Arc::new(Box::new(client_socket)), false); | ||||
|         sock_pair[1] = proc | ||||
|             .get_files() | ||||
|             .lock() | ||||
|             .unwrap() | ||||
|             .put(Arc::new(Box::new(server_socket)), false); | ||||
| 
 | ||||
|         info!("socketpair: ({}, {})", sock_pair[0], sock_pair[1]); | ||||
|         Ok(0) | ||||
|     } else if (domain == libc::AF_TIPC) { | ||||
|         return_errno!(EAFNOSUPPORT, "cluster domain sockets not supported") | ||||
|     } else { | ||||
|         return_errno!(EAFNOSUPPORT, "domain not supported") | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| fn do_select( | ||||
|     nfds: c_int, | ||||
|     readfds: *mut libc::fd_set, | ||||
|  | ||||
| @ -9,106 +9,182 @@ | ||||
| #include <string.h> | ||||
| #include <spawn.h> | ||||
| 
 | ||||
| const char SOCK_PATH[] = "echo_socket"; | ||||
| #include "test.h" | ||||
| 
 | ||||
| int create_server_socket() { | ||||
| 	int fd = socket(AF_UNIX, SOCK_STREAM, 0); | ||||
| 	if (fd == -1) { | ||||
| 		printf("ERROR: failed to create a unix socket\n"); | ||||
| 		return -1; | ||||
| 	} | ||||
| #define ECHO_MSG "echo msg for unix_socket test" | ||||
| #define THROW_ERROR(msg)    do {                      \ | ||||
|     printf("\t\tERROR: %s in func %s at line %d of file %s\n",  \ | ||||
|            (msg), __func__, __LINE__, __FILE__);            \ | ||||
|     return -1;                                              \ | ||||
| } while(0) | ||||
| 
 | ||||
| 	struct sockaddr_un local; | ||||
| 	local.sun_family = AF_UNIX; | ||||
| 	strcpy(local.sun_path, SOCK_PATH); | ||||
| 	socklen_t len = strlen(local.sun_path) + sizeof(local.sun_family); | ||||
| 
 | ||||
| 	if (bind(fd, (struct sockaddr *)&local, len) == -1) { | ||||
| 		printf("ERROR: failed to bind\n"); | ||||
| 		return -1; | ||||
| 	} | ||||
| 
 | ||||
| 	if (listen(fd, 5) == -1) { | ||||
| 		printf("ERROR: failed to listen\n"); | ||||
| 		return -1; | ||||
| 	} | ||||
| 	return fd; | ||||
| } | ||||
| 
 | ||||
| int create_client_socket() { | ||||
| 	int fd = socket(AF_UNIX, SOCK_STREAM, 0); | ||||
| 	if (fd == -1) { | ||||
| 		printf("ERROR: failed to create a unix socket\n"); | ||||
| 		return -1; | ||||
| 	} | ||||
| 
 | ||||
| 	struct sockaddr_un remote; | ||||
| 	remote.sun_family = AF_UNIX; | ||||
| 	strcpy(remote.sun_path, SOCK_PATH); | ||||
| 	socklen_t len = strlen(remote.sun_path) + sizeof(remote.sun_family); | ||||
| 
 | ||||
| 	if (connect(fd, (struct sockaddr *)&remote, len) == -1) { | ||||
| 		printf("ERROR: failed to connect\n"); | ||||
| 		return -1; | ||||
| 	} | ||||
| 	return fd; | ||||
| } | ||||
| 
 | ||||
| int main(int argc, const char* argv[]) { | ||||
| 	int listen_fd = create_server_socket(); | ||||
| int create_connected_sockets(int *sockets, char *sock_path) { | ||||
|     int listen_fd = socket(AF_UNIX, SOCK_STREAM, 0); | ||||
|     if (listen_fd == -1) { | ||||
| 		printf("ERROR: failed to create server socket\n"); | ||||
| 		return -1; | ||||
|         THROW_ERROR("failed to create a unix socket"); | ||||
|     } | ||||
| 
 | ||||
| 	int socket_rd_fd = create_client_socket(); | ||||
| 	if (socket_rd_fd == -1) { | ||||
| 		printf("ERROR: failed to create client socket\n"); | ||||
| 		return -1; | ||||
|     struct sockaddr_un addr; | ||||
|     memset(&addr, 0, sizeof(struct sockaddr_un)); //Clear structure
 | ||||
|     addr.sun_family = AF_UNIX; | ||||
|     strcpy(addr.sun_path, sock_path); | ||||
|     socklen_t addr_len = strlen(addr.sun_path) + sizeof(addr.sun_family); | ||||
|     if (bind(listen_fd, (struct sockaddr *)&addr, addr_len) == -1) { | ||||
|         close(listen_fd); | ||||
|         THROW_ERROR("failed to bind"); | ||||
|     } | ||||
| 
 | ||||
| 	struct sockaddr_un remote; | ||||
| 	socklen_t len = sizeof(remote); | ||||
| 	int socket_wr_fd = accept(listen_fd, (struct sockaddr *)&remote, &len); | ||||
| 	if (socket_wr_fd == -1) { | ||||
| 		printf("ERROR: failed to accept socket\n"); | ||||
| 		return -1; | ||||
|     if (listen(listen_fd, 5) == -1) { | ||||
|         close(listen_fd); | ||||
|         THROW_ERROR("failed to listen"); | ||||
|     } | ||||
| 
 | ||||
| 	// The following is same as 'pipe'
 | ||||
|     int client_fd = socket(AF_UNIX, SOCK_STREAM, 0); | ||||
|     if (client_fd == -1) { | ||||
|         close(listen_fd); | ||||
|         THROW_ERROR("failed to create a unix socket"); | ||||
|     } | ||||
| 
 | ||||
|     posix_spawn_file_actions_t file_actions; | ||||
|     posix_spawn_file_actions_init(&file_actions); | ||||
|     posix_spawn_file_actions_adddup2(&file_actions, socket_wr_fd, STDOUT_FILENO); | ||||
|     posix_spawn_file_actions_addclose(&file_actions, socket_rd_fd); | ||||
|     if (connect(client_fd, (struct sockaddr *)&addr, addr_len) == -1) { | ||||
|         close(listen_fd); | ||||
|         close(client_fd); | ||||
|         THROW_ERROR("failed to connect"); | ||||
|     } | ||||
| 
 | ||||
|     const char* msg = "Echo!\n"; | ||||
|     int accepted_fd = accept(listen_fd, (struct sockaddr *)&addr, &addr_len); | ||||
|     if (accepted_fd == -1) { | ||||
|         close(listen_fd); | ||||
|         close(client_fd); | ||||
|         THROW_ERROR("failed to accept socket"); | ||||
|     } | ||||
| 
 | ||||
|     sockets[0] = client_fd; | ||||
|     sockets[1] = accepted_fd; | ||||
|     close(listen_fd); | ||||
|     return 0; | ||||
| } | ||||
| 
 | ||||
| int create_connceted_sockets_default(int *sockets) { | ||||
|     return create_connected_sockets(sockets, "unix_socket_default_path"); | ||||
| } | ||||
| 
 | ||||
| int verify_child_echo(int *connected_sockets) { | ||||
|     const char* child_prog = "/bin/hello_world"; | ||||
|     const char* child_argv[3] = { child_prog, msg, NULL }; | ||||
|     const char* child_argv[3] = { child_prog, ECHO_MSG, NULL }; | ||||
|     int child_pid; | ||||
|     posix_spawn_file_actions_t file_actions; | ||||
| 
 | ||||
|     posix_spawn_file_actions_init(&file_actions); | ||||
|     posix_spawn_file_actions_adddup2(&file_actions, connected_sockets[0], STDOUT_FILENO); | ||||
|     posix_spawn_file_actions_addclose(&file_actions, connected_sockets[1]); | ||||
| 
 | ||||
|     if (posix_spawn(&child_pid, child_prog, &file_actions, | ||||
|             NULL, (char*const*)child_argv, NULL) < 0) { | ||||
|         printf("ERROR: failed to spawn a child process\n"); | ||||
|         return -1; | ||||
|         THROW_ERROR("failed to spawn a child process"); | ||||
|     } | ||||
|     close(socket_wr_fd); | ||||
| 
 | ||||
|     const char* expected_str = msg; | ||||
|     size_t expected_len = strlen(expected_str); | ||||
|     char actual_str[32] = {0}; | ||||
|     ssize_t actual_len; | ||||
|     //TODO: implement blocking read
 | ||||
|     do { | ||||
|         actual_len = read(socket_rd_fd, actual_str, sizeof(actual_str) - 1); | ||||
|         actual_len = read(connected_sockets[1], actual_str, 32); | ||||
|     } while (actual_len == 0); | ||||
|     if (strncmp(expected_str, actual_str, expected_len) != 0) { | ||||
|         printf("ERROR: received string is not as expected\n"); | ||||
|         return -1; | ||||
|     if (strncmp(actual_str, ECHO_MSG, sizeof(ECHO_MSG) - 1) != 0) { | ||||
|         printf("data read is :%s\n", actual_str); | ||||
|         THROW_ERROR("received string is not as expected"); | ||||
|     } | ||||
| 
 | ||||
|     int status = 0; | ||||
|     if (wait4(child_pid, &status, 0, NULL) < 0) { | ||||
|         printf("ERROR: failed to wait4 the child process\n"); | ||||
|         return -1; | ||||
|         THROW_ERROR("failed to wait4 the child process"); | ||||
|     } | ||||
| 
 | ||||
|     return 0; | ||||
| } | ||||
| 
 | ||||
| int verify_connection(int src_sock, int dest_sock) { | ||||
|     char buf[1024]; | ||||
|     int i; | ||||
|     for (i = 0; i < 100; i++) { | ||||
|         if (write(src_sock, ECHO_MSG, sizeof(ECHO_MSG)) < 0) { | ||||
|             THROW_ERROR("writing server message"); | ||||
|         } | ||||
| 
 | ||||
|         if (read(dest_sock, buf, 1024) < 0) { | ||||
|             THROW_ERROR("reading server message"); | ||||
|         } | ||||
| 
 | ||||
|         if (strncmp(buf, ECHO_MSG, sizeof(ECHO_MSG)) != 0) { | ||||
|             THROW_ERROR("msg received mismatch"); | ||||
|         } | ||||
|     } | ||||
|     return 0; | ||||
| } | ||||
| 
 | ||||
| //this value should not be too large as one pair consumes 2MB memory
 | ||||
| #define PAIR_NUM 15 | ||||
| 
 | ||||
| int test_multiple_socketpairs() { | ||||
|     int sockets[PAIR_NUM][2]; | ||||
|     int i; | ||||
|     int ret = 0; | ||||
| 
 | ||||
|     for(i = 0; i < PAIR_NUM; i++) { | ||||
|         if (socketpair(AF_UNIX, SOCK_STREAM, 0, sockets[i]) < 0) { | ||||
|             THROW_ERROR("opening stream socket pair"); | ||||
|         } | ||||
| 
 | ||||
|         if(verify_connection(sockets[i][0], sockets[i][1]) < 0) { | ||||
|             ret = -1; | ||||
|             goto cleanup; | ||||
|         } | ||||
|          | ||||
|         if(verify_connection(sockets[i][1], sockets[i][0]) < 0) { | ||||
|             ret = -1; | ||||
|             goto cleanup; | ||||
|         } | ||||
|     } | ||||
|     i--; | ||||
| cleanup: | ||||
|     for(; i >= 0; i--){ | ||||
|         close(sockets[i][0]); | ||||
|         close(sockets[i][1]); | ||||
|     } | ||||
|     return ret; | ||||
| } | ||||
| 
 | ||||
| int socketpair_default(int *sockets) { | ||||
|     return socketpair(AF_UNIX, SOCK_STREAM, 0, sockets); | ||||
| } | ||||
| 
 | ||||
| typedef int(*create_connection_func_t)(int *); | ||||
| int test_connected_sockets_inter_process(create_connection_func_t fn) { | ||||
|     int ret = 0; | ||||
|     int sockets[2]; | ||||
|     if (fn(sockets) < 0) | ||||
|         return -1; | ||||
| 
 | ||||
|     ret = verify_child_echo(sockets); | ||||
| 
 | ||||
|     close(sockets[0]); | ||||
|     close(sockets[1]); | ||||
|     return ret; | ||||
| } | ||||
| 
 | ||||
| int test_unix_socket_inter_process() { | ||||
|     return test_connected_sockets_inter_process(socketpair_default); | ||||
| } | ||||
| 
 | ||||
| int test_socketpair_inter_process() { | ||||
|     return test_connected_sockets_inter_process(create_connceted_sockets_default); | ||||
| } | ||||
| 
 | ||||
| static test_case_t test_cases[] = { | ||||
|     TEST_CASE(test_unix_socket_inter_process), | ||||
|     TEST_CASE(test_socketpair_inter_process), | ||||
|     TEST_CASE(test_multiple_socketpairs), | ||||
| }; | ||||
| 
 | ||||
| int main(int argc, const char* argv[]) { | ||||
|     return test_suite_run(test_cases, ARRAY_SIZE(test_cases)); | ||||
| } | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user